author | Alexander Ljungberg <aljungberg@wireload.net> | 2023-06-06 20:45:30 +0000
---|---|---
committer | GitHub <noreply@github.com> | 2023-06-06 20:45:30 +0000
commit | d9cc0910c8aca481f294009526897152901c32b9 (patch) |
tree | c850c5e4316b1cf830224dcae01aa5aecff589a3 |
parent | baf6946e06249c5af9851c60171692c44ef633e0 (diff) |
Fix upcast attention dtype error.
Without this fix, enabling the "Upcast cross attention layer to float32" option while also using `--opt-sdp-attention` breaks generation with an error:
```
File "/ext3/automatic1111/stable-diffusion-webui/modules/sd_hijack_optimizations.py", line 612, in sdp_attnblock_forward
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
RuntimeError: Expected query, key, and value to have the same dtype, but got query.dtype: float key.dtype: float and value.dtype: c10::Half instead.
```
The fix is to upcast the value tensor as well, so that `q`, `k`, and `v` all share the same dtype.
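For illustration, here is a minimal standalone sketch (not the webui code itself) of the failure mode and the fix: `torch.nn.functional.scaled_dot_product_attention` requires `q`, `k`, and `v` to share a dtype, so upcasting only the query and key while the model runs in float16 triggers the error above, and upcasting all three (then casting the result back) resolves it. The tensor shapes and variable names below are placeholders for the example.

```python
import torch
import torch.nn.functional as F

# Half-precision tensors standing in for the attention block's q/k/v
# (shapes are arbitrary placeholders for this example).
q = torch.randn(1, 64, 512, dtype=torch.half)
k = torch.randn(1, 64, 512, dtype=torch.half)
v = torch.randn(1, 64, 512, dtype=torch.half)

dtype = q.dtype  # remember the original precision

# Upcasting only q and k reproduces the mismatch that SDPA rejects:
# F.scaled_dot_product_attention(q.float(), k.float(), v)  # -> RuntimeError

# Upcasting the value tensor too keeps all three dtypes consistent.
out = F.scaled_dot_product_attention(
    q.float(), k.float(), v.float(), dropout_p=0.0, is_causal=False
)
out = out.to(dtype)  # cast the result back to the original (half) precision
print(out.dtype)  # torch.float16
```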
-rw-r--r-- | modules/sd_hijack_optimizations.py | 2
1 file changed, 1 insertion(+), 1 deletion(-)
```diff
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 80e48a42..6464ca8e 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -605,7 +605,7 @@ def sdp_attnblock_forward(self, x):
     q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
     dtype = q.dtype
     if shared.opts.upcast_attn:
-        q, k = q.float(), k.float()
+        q, k, v = q.float(), k.float(), v.float()
     q = q.contiguous()
     k = k.contiguous()
     v = v.contiguous()
```