diff options
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r-- | modules/sd_hijack.py | 74 |
1 files changed, 73 insertions, 1 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index c058ac6e..ec7d14cb 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -11,7 +11,7 @@ from modules.shared import opts, device, cmd_opts from ldm.util import default
from einops import rearrange
import ldm.modules.attention
-
+import ldm.modules.diffusionmodules.model
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
@@ -100,6 +100,76 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2)
+def nonlinearity_hijack(x):
+ # swish
+ t = torch.sigmoid(x)
+ x *= t
+ del t
+
+ return x
+
+def cross_attention_attnblock_forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q1 = self.q(h_)
+ k1 = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q1.shape
+
+ q2 = q1.reshape(b, c, h*w)
+ del q1
+
+ q = q2.permute(0, 2, 1) # b,hw,c
+ del q2
+
+ k = k1.reshape(b, c, h*w) # b,c,hw
+ del k1
+
+ h_ = torch.zeros_like(k, device=q.device)
+
+ stats = torch.cuda.memory_stats(q.device)
+ mem_active = stats['active_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = mem_free_cuda + mem_free_torch
+
+ tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
+ mem_required = tensor_size * 2.5
+ steps = 1
+
+ if mem_required > mem_free_total:
+ steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
+
+ slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+ for i in range(0, q.shape[1], slice_size):
+ end = i + slice_size
+
+ w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w2 = w1 * (int(c)**(-0.5))
+ del w1
+ w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
+ del w2
+
+ # attend to values
+ v1 = v.reshape(b, c, h*w)
+ w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ del w3
+
+ h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ del v1, w4
+
+ h2 = h_.reshape(b, c, h, w)
+ del h_
+
+ h3 = self.proj_out(h2)
+ del h2
+
+ h3 += x
+
+ return h3
class StableDiffusionModelHijack:
ids_lookup = {}
@@ -175,6 +245,8 @@ class StableDiffusionModelHijack: if cmd_opts.opt_split_attention:
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+ ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
elif cmd_opts.opt_split_attention_v1:
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
|