diff options
author | C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> | 2022-09-12 13:48:21 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-12 13:48:21 +0000 |
commit | aaea8b4494b09cb3cec0e85ec2e17e25a692520f (patch) | |
tree | 9c32b5b1abf2d15d68204ac5e8de9617aeadab93 /modules/sd_hijack.py | |
parent | a5a760a7d46781df42adb003642d46bd9136496e (diff) | |
download | stable-diffusion-webui-gfx803-aaea8b4494b09cb3cec0e85ec2e17e25a692520f.tar.gz stable-diffusion-webui-gfx803-aaea8b4494b09cb3cec0e85ec2e17e25a692520f.tar.bz2 stable-diffusion-webui-gfx803-aaea8b4494b09cb3cec0e85ec2e17e25a692520f.zip |
Update cross attention to the newest version
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r-- | modules/sd_hijack.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 9eb6cc20..c058ac6e 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -67,8 +67,9 @@ def split_cross_attention_forward(self, x, context=None, mask=None): mem_free_total = mem_free_cuda + mem_free_torch
gb = 1024 ** 3
- tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
- mem_required = tensor_size * 2.5
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
+ modifier = 3 if q.element_size() == 2 else 2.5
+ mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
@@ -86,7 +87,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
- s2 = s1.softmax(dim=-1)
+ s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|