author | AUTOMATIC1111 <16777216c@gmail.com> | 2022-10-08 13:29:59 +0000
---|---|---
committer | GitHub <noreply@github.com> | 2022-10-08 13:29:59 +0000
commit | 48feae37ff36915df9a3502a0a5aa1b7f146ab14 | (patch)
tree | 698e1ad35b2ee6d528ece2f9cdadc96aa22d3e54 | /modules/sd_hijack_optimizations.py
parent | 5f85a74b00c0154bfd559dc67edfa7e30342b7c9 | (diff)
parent | 970de9ee6891ff586821d0d80dde01c2f6c681b3 | (diff)
Merge pull request #1851 from C43H66N12O12S2/flash
xformers attention
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r-- | modules/sd_hijack_optimizations.py | 38 |
1 file changed, 37 insertions(+), 1 deletion(-)
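The patch below does two things: it wraps the xformers/functorch import in a try/except so the web UI falls back to split attention when the package is missing, and it adds two replacement forwards that route q/k/v through `xformers.ops.memory_efficient_attention`. As a rough orientation (not part of the patch), here is a minimal sketch comparing that call against plain softmax attention in PyTorch; the `(batch, tokens, heads, dim_per_head)` sizes and the tolerance are illustrative assumptions, and running it requires xformers plus a CUDA device.

```python
# Minimal sketch (not from the patch): memory_efficient_attention vs. naive attention.
# Assumes xformers is installed and a CUDA device is available; sizes are illustrative.
import torch
import xformers.ops

B, M, H, K = 2, 64, 8, 40                      # batch, tokens, heads, dim per head
q = torch.randn(B, M, H, K, device='cuda')
k = torch.randn(B, M, H, K, device='cuda')
v = torch.randn(B, M, H, K, device='cuda')

# xformers: fused memory-efficient attention over (B, M, H, K) tensors,
# the same layout the patched forward produces with rearrange.
out_xf = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)

# Reference: softmax(q @ k^T / sqrt(K)) @ v, computed per head.
qh, kh, vh = (t.transpose(1, 2) for t in (q, k, v))            # -> (B, H, M, K)
attn = torch.softmax(qh @ kh.transpose(-1, -2) / K**0.5, dim=-1)
out_ref = (attn @ vh).transpose(1, 2)                          # back to (B, M, H, K)

print(torch.allclose(out_xf, out_ref, atol=1e-4))              # expected: True (within tolerance)
```

The point of the swap is memory: the fused kernel avoids materializing the full tokens x tokens attention matrix that the einsum-based paths in this file build explicitly.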
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 3351c740..e43e2c7a 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -1,7 +1,14 @@
 import math
 import torch
 from torch import einsum
-
+try:
+    import xformers.ops
+    import functorch
+    xformers._is_functorch_available = True
+    shared.xformers_available = True
+except:
+    print('Cannot find xformers, defaulting to split attention. Try setting --xformers in your webui-user file if you wish to install it.')
+    pass
 from ldm.util import default
 from einops import rearrange
@@ -115,6 +122,25 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
     return self.to_out(r2)
+def xformers_attention_forward(self, x, context=None, mask=None):
+    h = self.heads
+    q_in = self.to_q(x)
+    context = default(context, x)
+    hypernetwork = shared.selected_hypernetwork()
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+    if hypernetwork_layers is not None:
+        k_in = self.to_k(hypernetwork_layers[0](context))
+        v_in = self.to_v(hypernetwork_layers[1](context))
+    else:
+        k_in = self.to_k(context)
+        v_in = self.to_v(context)
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
+    del q_in, k_in, v_in
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+
+    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
+    return self.to_out(out)
+
 def cross_attention_attnblock_forward(self, x):
     h_ = x
     h_ = self.norm(h_)
@@ -177,3 +203,13 @@ def cross_attention_attnblock_forward(self, x):
     h3 += x
     return h3
+
+def xformers_attnblock_forward(self, x):
+    h_ = x
+    h_ = self.norm(h_)
+    q1 = self.q(h_).contiguous()
+    k1 = self.k(h_).contiguous()
+    v = self.v(h_).contiguous()
+    out = xformers.ops.memory_efficient_attention(q1, k1, v)
+    out = self.proj_out(out)
+    return x + out
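The diff above only defines the replacement forwards; the code that swaps them into the model's attention classes lives elsewhere in the repository and is not part of this diff. As a hypothetical illustration of that monkey-patching pattern, the sketch below builds a minimal CrossAttention-like module exposing the attributes `xformers_attention_forward` relies on (`heads`, `to_q`, `to_k`, `to_v`, `to_out`) and rebinds its `forward`; the `ToyCrossAttention` class, the tensor sizes, and the omission of the hypernetwork branch are all assumptions for illustration, and it again requires xformers and CUDA.

```python
# Hypothetical sketch (not part of this diff): rebind a module's forward to a
# memory-efficient attention implementation. Dimensions are illustrative.
import torch
import torch.nn as nn
from einops import rearrange
import xformers.ops


class ToyCrossAttention(nn.Module):
    def __init__(self, query_dim=320, context_dim=768, heads=8, dim_head=40):
        super().__init__()
        inner = heads * dim_head
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner, bias=False)
        self.to_k = nn.Linear(context_dim, inner, bias=False)
        self.to_v = nn.Linear(context_dim, inner, bias=False)
        self.to_out = nn.Linear(inner, query_dim)

    def forward(self, x, context=None, mask=None):
        raise NotImplementedError  # placeholder; replaced by the hijack below


def xformers_attention_forward(self, x, context=None, mask=None):
    # Simplified version of the patched forward (hypernetwork branch omitted).
    h = self.heads
    context = x if context is None else context
    q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
    q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q, k, v))
    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
    return self.to_out(rearrange(out, 'b n h d -> b n (h d)', h=h))


# The hijack: swap the class's forward for the memory-efficient one.
ToyCrossAttention.forward = xformers_attention_forward

attn = ToyCrossAttention().cuda()
x = torch.randn(1, 4096, 320, device='cuda')    # e.g. a 64x64 latent flattened to 4096 tokens
ctx = torch.randn(1, 77, 768, device='cuda')    # e.g. a CLIP text embedding
print(attn(x, ctx).shape)                       # torch.Size([1, 4096, 320])
```

In the web UI the analogous rebinding is handled by its hijack code outside this file; the hypernetwork lookup in the patched forward is what allows optional learned layers to sit in front of the k/v projections.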