path: root/modules/sd_hijack_optimizations.py
author    AUTOMATIC <16777216c@gmail.com>  2023-05-20 19:29:51 +0000
committer AUTOMATIC <16777216c@gmail.com>  2023-05-20 19:29:51 +0000
commit    05e6fc9aa944dd6e3ee01eae0817f8b51134ffab (patch)
tree      ee6c49a0f4f3f33ac8a9dac8938138f68a867f4a /modules/sd_hijack_optimizations.py
parent    cc6c0fc70a8fee1ea01a5e1a63d4edd645b26687 (diff)
parent    2140bd1c108dd17bbf8601b10da7865ed1ac1607 (diff)
Merge branch 'ui-selection-for-cross-attention-optimization' into dev
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r--  modules/sd_hijack_optimizations.py  125
1 file changed, 122 insertions, 3 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index f00fe55c..0eb4c525 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -9,10 +9,129 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
-from modules import shared, errors, devices
+from modules import shared, errors, devices, sub_quadratic_attention
from modules.hypernetworks import hypernetwork
-from .sub_quadratic_attention import efficient_dot_product_attention
+import ldm.modules.attention
+import ldm.modules.diffusionmodules.model
+
+diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
+
+
+class SdOptimization:
+ name: str | None = None
+ label: str | None = None
+ cmd_opt: str | None = None
+ priority: int = 0
+
+ def title(self):
+ if self.label is None:
+ return self.name
+
+ return f"{self.name} - {self.label}"
+
+ def is_available(self):
+ return True
+
+ def apply(self):
+ pass
+
+ def undo(self):
+ ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+
+
+class SdOptimizationXformers(SdOptimization):
+ name = "xformers"
+ cmd_opt = "xformers"
+ priority = 100
+
+ def is_available(self):
+ return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
+
+
+class SdOptimizationSdpNoMem(SdOptimization):
+ name = "sdp-no-mem"
+ label = "scaled dot product without memory efficient attention"
+ cmd_opt = "opt_sdp_no_mem_attention"
+ priority = 90
+
+ def is_available(self):
+ return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
+
+
+class SdOptimizationSdp(SdOptimizationSdpNoMem):
+ name = "sdp"
+ label = "scaled dot product"
+ cmd_opt = "opt_sdp_attention"
+ priority = 80
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
+
+
+class SdOptimizationSubQuad(SdOptimization):
+ name = "sub-quadratic"
+ cmd_opt = "opt_sub_quad_attention"
+ priority = 10
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
+
+
+class SdOptimizationV1(SdOptimization):
+ name = "V1"
+ label = "original v1"
+ cmd_opt = "opt_split_attention_v1"
+ priority = 10
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
+
+
+class SdOptimizationInvokeAI(SdOptimization):
+ name = "InvokeAI"
+ cmd_opt = "opt_split_attention_invokeai"
+
+ @property
+ def priority(self):
+ return 1000 if not torch.cuda.is_available() else 10
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
+
+
+class SdOptimizationDoggettx(SdOptimization):
+ name = "Doggettx"
+ cmd_opt = "opt_split_attention"
+ priority = 20
+
+ def apply(self):
+ ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
+
+
+def list_optimizers(res):
+ res.extend([
+ SdOptimizationXformers(),
+ SdOptimizationSdpNoMem(),
+ SdOptimizationSdp(),
+ SdOptimizationSubQuad(),
+ SdOptimizationV1(),
+ SdOptimizationInvokeAI(),
+ SdOptimizationDoggettx(),
+ ])
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
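The added block above turns the optimization methods into a small registry: each SdOptimization subclass reports whether it can run (is_available), how preferred it is (priority), and how to patch the attention forwards (apply/undo), and list_optimizers() collects one instance of each. Below is a minimal sketch of how a caller might consume that list, either honoring a named selection or falling back to the highest-priority available optimization; the select_and_apply_optimization() helper and its "Automatic" selection value are illustrative assumptions, not code from this commit.

def select_and_apply_optimization(selection="Automatic"):
    # Illustrative helper (not part of this diff): gather the registered optimizations.
    optimizers = []
    list_optimizers(optimizers)

    # Keep only those that can run in the current environment (xformers installed,
    # suitable CUDA capability, scaled_dot_product_attention present, and so on).
    available = [o for o in optimizers if o.is_available()]
    if not available:
        return None

    if selection == "Automatic":
        # Fall back to the built-in priorities (a property on InvokeAI, plain ints elsewhere).
        chosen = max(available, key=lambda o: o.priority)
    else:
        chosen = next((o for o in available if o.name == selection), None)
        if chosen is None:
            return None

    chosen.apply()
    return chosen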
@@ -299,7 +418,7 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
kv_chunk_size = k_tokens
with devices.without_autocast(disable=q.dtype == v.dtype):
- return efficient_dot_product_attention(
+ return sub_quadratic_attention.efficient_dot_product_attention(
q,
k,
v,