diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/devices.py | 9 | ||||
-rw-r--r-- | modules/import_hook.py | 5 | ||||
-rw-r--r-- | modules/safe.py | 12 | ||||
-rw-r--r-- | modules/sd_hijack_optimizations.py | 10 |
4 files changed, 26 insertions, 10 deletions
diff --git a/modules/devices.py b/modules/devices.py index f8cffae1..800510b7 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -125,7 +125,16 @@ def layer_norm_fix(*args, **kwargs): return orig_layer_norm(*args, **kwargs) +# MPS workaround for https://github.com/pytorch/pytorch/issues/90532 +orig_tensor_numpy = torch.Tensor.numpy +def numpy_fix(self, *args, **kwargs): + if self.requires_grad: + self = self.detach() + return orig_tensor_numpy(self, *args, **kwargs) + + # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working if has_mps() and version.parse(torch.__version__) < version.parse("1.13"): torch.Tensor.to = tensor_to_fix torch.nn.functional.layer_norm = layer_norm_fix + torch.Tensor.numpy = numpy_fix diff --git a/modules/import_hook.py b/modules/import_hook.py new file mode 100644 index 00000000..28c67dfa --- /dev/null +++ b/modules/import_hook.py @@ -0,0 +1,5 @@ +import sys + +# this will break any attempt to import xformers which will prevent stability diffusion repo from trying to use it +if "--xformers" not in "".join(sys.argv): + sys.modules["xformers"] = None diff --git a/modules/safe.py b/modules/safe.py index 10460ad0..7c89c4c2 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -37,16 +37,16 @@ class RestrictedUnpickler(pickle.Unpickler): if module == 'collections' and name == 'OrderedDict':
return getattr(collections, name)
- if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
+ if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter', '_rebuild_device_tensor_from_numpy']:
return getattr(torch._utils, name)
- if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']:
+ if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage', 'float32']:
return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
return getattr(torch.nn.modules.container, name)
- if module == 'numpy.core.multiarray' and name == 'scalar':
- return numpy.core.multiarray.scalar
- if module == 'numpy' and name == 'dtype':
- return numpy.dtype
+ if module == 'numpy.core.multiarray' and name in ['scalar', '_reconstruct']:
+ return getattr(numpy.core.multiarray, name)
+ if module == 'numpy' and name in ['dtype', 'ndarray']:
+ return getattr(numpy, name)
if module == '_codecs' and name == 'encode':
return encode
if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 98123fbf..02c87f40 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -127,7 +127,7 @@ def check_for_psutil(): invokeAI_mps_available = check_for_psutil()
-# -- Taken from https://github.com/invoke-ai/InvokeAI --
+# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
if invokeAI_mps_available:
import psutil
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
@@ -152,14 +152,16 @@ def einsum_op_slice_1(q, k, v, slice_size): return r
def einsum_op_mps_v1(q, k, v):
- if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+ if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
return einsum_op_compvis(q, k, v)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+ if slice_size % 4096 == 0:
+ slice_size -= 1
return einsum_op_slice_1(q, k, v, slice_size)
def einsum_op_mps_v2(q, k, v):
- if mem_total_gb > 8 and q.shape[1] <= 4096:
+ if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
return einsum_op_compvis(q, k, v)
else:
return einsum_op_slice_0(q, k, v, 1)
@@ -188,7 +190,7 @@ def einsum_op(q, k, v): return einsum_op_cuda(q, k, v)
if q.device.type == 'mps':
- if mem_total_gb >= 32:
+ if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
return einsum_op_mps_v1(q, k, v)
return einsum_op_mps_v2(q, k, v)
|