aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sub_quadratic_attention.py
diff options
context:
space:
mode:
authorbrkirch <brkirch@users.noreply.github.com>2023-01-25 05:23:10 +0000
committerbrkirch <brkirch@users.noreply.github.com>2023-01-25 06:13:04 +0000
commite3b53fd295aca784253dfc8668ec87b537a72f43 (patch)
tree6fb26afd730c0561a2506ead2d2c8295d326de40 /modules/sub_quadratic_attention.py
parent84d9ce30cb427759547bc7876ed80ab91787d175 (diff)
downloadstable-diffusion-webui-gfx803-e3b53fd295aca784253dfc8668ec87b537a72f43.tar.gz
stable-diffusion-webui-gfx803-e3b53fd295aca784253dfc8668ec87b537a72f43.tar.bz2
stable-diffusion-webui-gfx803-e3b53fd295aca784253dfc8668ec87b537a72f43.zip
Add UI setting for upcasting attention to float32
Adds "Upcast cross attention layer to float32" option in Stable Diffusion settings. This allows for generating images using SD 2.1 models without --no-half or xFormers. In order to make upcasting cross attention layer optimizations possible it is necessary to indent several sections of code in sd_hijack_optimizations.py so that a context manager can be used to disable autocast. Also, even though Stable Diffusion (and Diffusers) only upcast q and k, unfortunately my findings were that most of the cross attention layer optimizations could not function unless v is upcast also.
Diffstat (limited to 'modules/sub_quadratic_attention.py')
-rw-r--r--modules/sub_quadratic_attention.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index 55052815..05595323 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -67,7 +67,7 @@ def _summarize_chunk(
max_score, _ = torch.max(attn_weights, -1, keepdim=True)
max_score = max_score.detach()
exp_weights = torch.exp(attn_weights - max_score)
- exp_values = torch.bmm(exp_weights, value)
+ exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)
max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
@@ -129,7 +129,7 @@ def _get_attention_scores_no_kv_chunking(
)
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
- hidden_states_slice = torch.bmm(attn_probs, value)
+ hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
return hidden_states_slice