aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_hijack_optimizations.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-07-13 06:30:33 +0000
committerAUTOMATIC1111 <16777216c@gmail.com>2023-07-13 06:30:33 +0000
commitac4ccfa1369e74492b467294eab96c3f558b297b (patch)
tree357cc16bc3a82c2c4a07c410e3cd4e83e5e3b6f9 /modules/sd_hijack_optimizations.py
parentb717eb7e56a4e620e77a2225e80223c89cb4f0d1 (diff)
downloadstable-diffusion-webui-gfx803-ac4ccfa1369e74492b467294eab96c3f558b297b.tar.gz
stable-diffusion-webui-gfx803-ac4ccfa1369e74492b467294eab96c3f558b297b.tar.bz2
stable-diffusion-webui-gfx803-ac4ccfa1369e74492b467294eab96c3f558b297b.zip
get attention optimizations to work
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r--modules/sd_hijack_optimizations.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index e99c9ba5..b5f85ba5 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -173,7 +173,7 @@ def get_available_vram():
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
-def split_cross_attention_forward_v1(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
+def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):
h = self.heads
q_in = self.to_q(x)
@@ -214,7 +214,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None, additiona
# taken from https://github.com/Doggettx/stable-diffusion and modified
-def split_cross_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
+def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
h = self.heads
q_in = self.to_q(x)
@@ -355,7 +355,7 @@ def einsum_op(q, k, v):
return einsum_op_tensor_mem(q, k, v, 32)
-def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
+def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):
h = self.heads
q = self.to_q(x)
@@ -383,7 +383,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, add
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
-def sub_quad_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
+def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs):
assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."
h = self.heads
@@ -470,7 +470,7 @@ def get_xformers_flash_attention_op(q, k, v):
return None
-def xformers_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
+def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
@@ -496,7 +496,7 @@ def xformers_attention_forward(self, x, context=None, mask=None, additional_toke
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
-def scaled_dot_product_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
+def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs):
batch_size, sequence_length, inner_dim = x.shape
if mask is not None:
@@ -537,7 +537,7 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None, addit
return hidden_states
-def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
+def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
return scaled_dot_product_attention_forward(self, x, context, mask)