diff options
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r-- | modules/sd_hijack_optimizations.py | 60 |
1 files changed, 56 insertions, 4 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index d9cca485..d23d733b 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,4 +1,7 @@ import math
+import sys
+import traceback
+
import torch
from torch import einsum
@@ -7,18 +10,37 @@ from einops import rearrange from modules import shared
+if shared.cmd_opts.xformers:
+ try:
+ import xformers.ops
+ import functorch
+ xformers._is_functorch_available = True
+ shared.xformers_available = True
+ except Exception:
+ print("Cannot import xformers", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
h = self.heads
- q = self.to_q(x)
+ q_in = self.to_q(x)
context = default(context, x)
- k = self.to_k(context)
- v = self.to_v(context)
+
+ hypernetwork = shared.selected_hypernetwork()
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+ if hypernetwork_layers is not None:
+ k_in = self.to_k(hypernetwork_layers[0](context))
+ v_in = self.to_v(hypernetwork_layers[1](context))
+ else:
+ k_in = self.to_k(context)
+ v_in = self.to_v(context)
del context, x
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+ del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
for i in range(0, q.shape[0], 2):
@@ -31,6 +53,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
+ del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
@@ -105,6 +128,25 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2)
+def xformers_attention_forward(self, x, context=None, mask=None):
+ h = self.heads
+ q_in = self.to_q(x)
+ context = default(context, x)
+ hypernetwork = shared.selected_hypernetwork()
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+ if hypernetwork_layers is not None:
+ k_in = self.to_k(hypernetwork_layers[0](context))
+ v_in = self.to_v(hypernetwork_layers[1](context))
+ else:
+ k_in = self.to_k(context)
+ v_in = self.to_v(context)
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
+ del q_in, k_in, v_in
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+
+ out = rearrange(out, 'b n h d -> b n (h d)', h=h)
+ return self.to_out(out)
+
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
@@ -167,3 +209,13 @@ def cross_attention_attnblock_forward(self, x): h3 += x
return h3
+
+def xformers_attnblock_forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q1 = self.q(h_).contiguous()
+ k1 = self.k(h_).contiguous()
+ v = self.v(h_).contiguous()
+ out = xformers.ops.memory_efficient_attention(q1, k1, v)
+ out = self.proj_out(out)
+ return x+out
|