author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-01-07 09:26:55 +0000
committer | GitHub <noreply@github.com> | 2023-01-07 09:26:55 +0000
commit | c295e4a2446bcc2663f497ba8afa14cec80de332 (patch)
tree | 606ede9bd1bf0c13b59c26a63755a2f95f6b8da6 /modules/sd_hijack.py
parent | 1a5b86ad65fd738eadea1ad72f4abad3a4aabf17 (diff)
parent | c18add68ef7d2de3617cbbaff864b0c74cfdf6c0 (diff)
Merge pull request #6055 from brkirch/sub-quad_attn_opt
Add Birch-san's sub-quadratic attention implementation
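For context, the "sub-quadratic" optimization is a memory-efficient attention scheme: instead of materializing the full (q_len × kv_len) score matrix, keys and values are processed in chunks with a streaming softmax, trading extra compute passes for much lower peak memory. The snippet below is a minimal sketch of that idea only; the function name, shapes, and kv_chunk_size parameter are illustrative, and it is not the actual sd_hijack_optimizations.sub_quad_attention_forward code added by this merge.

```python
import torch

def sub_quad_attention_sketch(q, k, v, kv_chunk_size=1024):
    # Hypothetical illustration of memory-efficient ("sub-quadratic") attention:
    # keys/values are visited in chunks with a streaming softmax, so the full
    # (q_len x kv_len) score matrix is never held in memory at once.
    # q: (batch, q_len, dim); k, v: (batch, kv_len, dim)
    scale = q.shape[-1] ** -0.5
    b, q_len, _ = q.shape
    acc = torch.zeros_like(q)                            # running sum of softmax-weighted values
    row_sum = q.new_zeros(b, q_len, 1)                   # running softmax denominator
    row_max = q.new_full((b, q_len, 1), float("-inf"))   # running max for numerical stability

    for start in range(0, k.shape[1], kv_chunk_size):
        k_chunk = k[:, start:start + kv_chunk_size]
        v_chunk = v[:, start:start + kv_chunk_size]
        scores = torch.einsum("bqd,bkd->bqk", q, k_chunk) * scale

        chunk_max = scores.amax(dim=-1, keepdim=True)
        new_max = torch.maximum(row_max, chunk_max)
        correction = torch.exp(row_max - new_max)        # rescale old accumulators to the new max
        weights = torch.exp(scores - new_max)

        acc = acc * correction + torch.einsum("bqk,bkd->bqd", weights, v_chunk)
        row_sum = row_sum * correction + weights.sum(dim=-1, keepdim=True)
        row_max = new_max

    return acc / row_sum
```

In the diff below, this style of attention is wired in by replacing ldm.modules.attention.CrossAttention.forward (and the corresponding AttnBlock.forward) when cmd_opts.opt_sub_quad_attention is set.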
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r-- | modules/sd_hijack.py | 21
1 file changed, 9 insertions, 12 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 71cc145a..cfdb09d6 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -7,8 +7,6 @@ from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
-from modules.sd_hijack_optimizations import invokeAI_mps_available
-
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
@@ -43,20 +41,19 @@ def apply_optimizations():
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
optimization_method = 'xformers'
+ elif cmd_opts.opt_sub_quad_attention:
+ print("Applying sub-quadratic cross attention optimization.")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
+ optimization_method = 'sub-quadratic'
elif cmd_opts.opt_split_attention_v1:
print("Applying v1 cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
optimization_method = 'V1'
- elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()):
- if not invokeAI_mps_available and shared.device.type == 'mps':
- print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.")
- print("Applying v1 cross attention optimization.")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
- optimization_method = 'V1'
- else:
- print("Applying cross attention optimization (InvokeAI).")
- ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
- optimization_method = 'InvokeAI'
+ elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
+ print("Applying cross attention optimization (InvokeAI).")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
+ optimization_method = 'InvokeAI'
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
print("Applying cross attention optimization (Doggettx).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
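For reference, the branch order in apply_optimizations() after this merge is: xformers, then the new sub-quadratic option, then v1, then InvokeAI, then the Doggettx split-attention path. The reworked InvokeAI condition also means that an explicit opt_split_attention request on a machine without CUDA now falls through to the Doggettx branch instead of being rerouted to InvokeAI, and the old psutil/MPS fallback to v1 is gone. A minimal sketch of that precedence (not the actual apply_optimizations() body; the xformers condition sits above this hunk and is stood in for by a placeholder flag):

```python
# Sketch of the selection order implied by the hunk above -- flag names mirror
# the cmd_opts attributes in the diff; "use_xformers" is a placeholder for the
# xformers condition that is not shown in this hunk.
def pick_optimization(opts, cuda_available):
    if opts.use_xformers:
        return "xformers"
    if opts.opt_sub_quad_attention:                      # new branch from this merge
        return "sub-quadratic"
    if opts.opt_split_attention_v1:
        return "V1"
    if not opts.disable_opt_split_attention and (
        opts.opt_split_attention_invokeai
        or (not opts.opt_split_attention and not cuda_available)
    ):
        return "InvokeAI"
    if not opts.disable_opt_split_attention and (opts.opt_split_attention or cuda_available):
        return "Doggettx"
    return None
```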