author    | AUTOMATIC <16777216c@gmail.com> | 2023-05-18 19:48:28 +0000
committer | AUTOMATIC <16777216c@gmail.com> | 2023-05-18 19:48:28 +0000
commit    | 2582a0fd3b3e91c5fba9e5e561cbdf5fee835063 (patch)
tree      | 656cfdfe0b8eb4ed40ce03581d4abc2d4af0e6dd /modules/sd_hijack_optimizations.py
parent    | 2e006fa50046440c81e663d0e833a85b35258d41 (diff)
make it possible for scripts to add cross attention optimizations
add UI selection for cross attention optimization
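For extension authors, the new mechanism makes it possible to register a custom optimization from a script. The sketch below is illustrative only: `SdOptimizationMine`, `my_attention_forward`, and `add_my_optimization` are made-up names, and the `script_callbacks.on_list_optimizers` registration call is assumed from the accompanying script_callbacks change, which is not part of this file's diff.

```python
# Hypothetical extension script (names are illustrative, not part of this commit).
# Assumes the accompanying script_callbacks change exposes an
# on_list_optimizers(callback) hook whose callback receives the same list
# that list_optimizers() extends.
import ldm.modules.attention

from modules import script_callbacks
from modules.sd_hijack_optimizations import SdOptimization


def my_attention_forward(self, x, context=None, mask=None):
    # A custom replacement for CrossAttention.forward would go here.
    ...


class SdOptimizationMine(SdOptimization):
    def __init__(self):
        super().__init__("my-optimization", label="example extension optimization")

    def is_available(self):
        return True  # e.g. check for a required library or GPU capability

    def priority(self):
        return 50  # picked between sdp (80) and Doggettx (20)

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = my_attention_forward


def add_my_optimization(res):
    res.append(SdOptimizationMine())


script_callbacks.on_list_optimizers(add_my_optimization)
```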
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r-- | modules/sd_hijack_optimizations.py | 135 |
1 file changed, 132 insertions, 3 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index f00fe55c..1c5b709b 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -9,10 +9,139 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
-from modules import shared, errors, devices
+from modules import shared, errors, devices, sub_quadratic_attention, script_callbacks
from modules.hypernetworks import hypernetwork
-from .sub_quadratic_attention import efficient_dot_product_attention
+import ldm.modules.attention
+import ldm.modules.diffusionmodules.model
+
+diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
+
+
+class SdOptimization:
+    def __init__(self, name, label=None, cmd_opt=None):
+        self.name = name
+        self.label = label
+        self.cmd_opt = cmd_opt
+
+    def title(self):
+        if self.label is None:
+            return self.name
+
+        return f"{self.name} - {self.label}"
+
+    def is_available(self):
+        return True
+
+    def priority(self):
+        return 0
+
+    def apply(self):
+        pass
+
+    def undo(self):
+        ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+
+
+class SdOptimizationXformers(SdOptimization):
+    def __init__(self):
+        super().__init__("xformers", cmd_opt="xformers")
+
+    def is_available(self):
+        return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
+
+    def priority(self):
+        return 100
+
+    def apply(self):
+        ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
+
+
+class SdOptimizationSdpNoMem(SdOptimization):
+    def __init__(self, name="sdp-no-mem", label="scaled dot product without memory efficient attention", cmd_opt="opt_sdp_no_mem_attention"):
+        super().__init__(name, label, cmd_opt)
+
+    def is_available(self):
+        return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
+
+    def priority(self):
+        return 90
+
+    def apply(self):
+        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
+
+
+class SdOptimizationSdp(SdOptimizationSdpNoMem):
+    def __init__(self):
+        super().__init__("sdp", "scaled dot product", cmd_opt="opt_sdp_attention")
+
+    def priority(self):
+        return 80
+
+    def apply(self):
+        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
+
+
+class SdOptimizationSubQuad(SdOptimization):
+    def __init__(self):
+        super().__init__("sub-quadratic", cmd_opt="opt_sub_quad_attention")
+
+    def priority(self):
+        return 10
+
+    def apply(self):
+        ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
+
+
+class SdOptimizationV1(SdOptimization):
+    def __init__(self):
+        super().__init__("V1", "original v1", cmd_opt="opt_split_attention_v1")
+
+    def priority(self):
+        return 10
+
+    def apply(self):
+        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
+
+
+class SdOptimizationInvokeAI(SdOptimization):
+    def __init__(self):
+        super().__init__("InvokeAI", cmd_opt="opt_split_attention_invokeai")
+
+    def priority(self):
+        return 1000 if not torch.cuda.is_available() else 10
+
+    def apply(self):
+        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
+
+
+class SdOptimizationDoggettx(SdOptimization):
+    def __init__(self):
+        super().__init__("Doggettx", cmd_opt="opt_split_attention")
+
+    def priority(self):
+        return 20
+
+    def apply(self):
+        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
+
+
+def list_optimizers(res):
+    res.extend([
+        SdOptimizationXformers(),
+        SdOptimizationSdpNoMem(),
+        SdOptimizationSdp(),
+        SdOptimizationSubQuad(),
+        SdOptimizationV1(),
+        SdOptimizationInvokeAI(),
+        SdOptimizationDoggettx(),
+    ])
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
@@ -299,7 +428,7 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
        kv_chunk_size = k_tokens

    with devices.without_autocast(disable=q.dtype == v.dtype):
-        return efficient_dot_product_attention(
+        return sub_quadratic_attention.efficient_dot_product_attention(
            q,
            k,
            v,
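The diff above only defines and lists the optimizations; the code that actually chooses one (driven by the new UI selection and the existing command line flags) lives in sd_hijack.py and is not shown here. As a rough illustration of how a caller might rank and apply them, here is an assumption-laden sketch, not the webui's actual selection code:

```python
# Illustrative only: one plausible way a caller could consume the
# SdOptimization objects gathered by list_optimizers(). The real selection
# logic lives in modules/sd_hijack.py and may differ.
from modules import sd_hijack_optimizations


def apply_optimization(selected_title=None):
    optimizers = []
    sd_hijack_optimizations.list_optimizers(optimizers)

    # Drop optimizations that cannot run in this environment
    # (e.g. xformers without a supported CUDA device).
    available = [o for o in optimizers if o.is_available()]
    if not available:
        return None

    if selected_title is not None:
        # Honour an explicit choice from the new UI dropdown.
        matches = [o for o in available if o.title() == selected_title]
    else:
        matches = available

    if not matches:
        return None

    # Otherwise prefer the highest-priority optimization; e.g. InvokeAI
    # reports 1000 when CUDA is unavailable, so it wins on CPU-only setups.
    best = max(matches, key=lambda o: o.priority())
    best.apply()
    return best
```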