diff options
author | brkirch <brkirch@users.noreply.github.com> | 2022-12-19 22:25:14 +0000 |
---|---|---|
committer | brkirch <brkirch@users.noreply.github.com> | 2022-12-21 02:30:00 +0000 |
commit | 35b1775b32a07f1b7c9dccad61f7aa77027a00fa (patch) | |
tree | 59f6583295d24b8f3c985da47d16089dfc5eeee4 /modules | |
parent | 685f9631b56ff8bd43bce24ff5ce0f9a0e9af490 (diff) | |
download | stable-diffusion-webui-gfx803-35b1775b32a07f1b7c9dccad61f7aa77027a00fa.tar.gz stable-diffusion-webui-gfx803-35b1775b32a07f1b7c9dccad61f7aa77027a00fa.tar.bz2 stable-diffusion-webui-gfx803-35b1775b32a07f1b7c9dccad61f7aa77027a00fa.zip |
Use other MPS optimization for large q.shape[0] * q.shape[1]
Check if q.shape[0] * q.shape[1] is 2**18 or larger and use the lower memory usage MPS optimization if it is. This should prevent most crashes that were occurring at certain resolutions (e.g. 1024x1024, 2048x512, 512x2048).
Also included is a change to check slice_size and prevent it from being divisible by 4096 which also results in a crash. Otherwise a crash can occur at 1024x512 or 512x1024 resolution.
Diffstat (limited to 'modules')
-rw-r--r-- | modules/sd_hijack_optimizations.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 98123fbf..02c87f40 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -127,7 +127,7 @@ def check_for_psutil(): invokeAI_mps_available = check_for_psutil()
-# -- Taken from https://github.com/invoke-ai/InvokeAI --
+# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
if invokeAI_mps_available:
import psutil
mem_total_gb = psutil.virtual_memory().total // (1 << 30)
@@ -152,14 +152,16 @@ def einsum_op_slice_1(q, k, v, slice_size): return r
def einsum_op_mps_v1(q, k, v):
- if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+ if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
return einsum_op_compvis(q, k, v)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
+ if slice_size % 4096 == 0:
+ slice_size -= 1
return einsum_op_slice_1(q, k, v, slice_size)
def einsum_op_mps_v2(q, k, v):
- if mem_total_gb > 8 and q.shape[1] <= 4096:
+ if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
return einsum_op_compvis(q, k, v)
else:
return einsum_op_slice_0(q, k, v, 1)
@@ -188,7 +190,7 @@ def einsum_op(q, k, v): return einsum_op_cuda(q, k, v)
if q.device.type == 'mps':
- if mem_total_gb >= 32:
+ if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
return einsum_op_mps_v1(q, k, v)
return einsum_op_mps_v2(q, k, v)
|