aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_hijack_optimizations.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r--modules/sd_hijack_optimizations.py15
1 files changed, 7 insertions, 8 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index f10865cd..a174bbe1 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -49,7 +49,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
v_in = self.to_v(context_v)
del context, context_k, context_v, x
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in
dtype = q.dtype
@@ -98,7 +98,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
del context, x
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
@@ -229,7 +229,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
with devices.without_autocast(disable=not shared.opts.upcast_attn):
k = k * self.scale
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+ q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
r = einsum_op(q, k, v)
r = r.to(dtype)
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
@@ -296,7 +296,6 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
# the big matmul fits into our memory limit; do everything in 1 chunk,
# i.e. send it down the unchunked fast-path
- query_chunk_size = q_tokens
kv_chunk_size = k_tokens
with devices.without_autocast(disable=q.dtype == v.dtype):
@@ -335,7 +334,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
+ q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
del q_in, k_in, v_in
dtype = q.dtype
@@ -461,7 +460,7 @@ def xformers_attnblock_forward(self, x):
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
@@ -483,7 +482,7 @@ def sdp_attnblock_forward(self, x):
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
@@ -507,7 +506,7 @@ def sub_quad_attnblock_forward(self, x):
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
- q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+ q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()