diff options
author | Karun <karun.ellango7@gmail.com> | 2023-03-25 09:12:55 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-25 09:12:55 +0000 |
commit | 63a2f8d8225228a52b3ca7f19d2ba1fd07a6234b (patch) | |
tree | 9a7c38070d83b409895704125525dfc70cc21215 /modules/sd_hijack.py | |
parent | ca2b8faa83076a21dd14c974f03f88eb6da57485 (diff) | |
parent | 70615448b2ef3285dba9bb1992974cb1eaf10995 (diff) | |
download | stable-diffusion-webui-gfx803-63a2f8d8225228a52b3ca7f19d2ba1fd07a6234b.tar.gz stable-diffusion-webui-gfx803-63a2f8d8225228a52b3ca7f19d2ba1fd07a6234b.tar.bz2 stable-diffusion-webui-gfx803-63a2f8d8225228a52b3ca7f19d2ba1fd07a6234b.zip |
Merge branch 'master' into master
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r-- | modules/sd_hijack.py | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 79476783..f4bb0266 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -37,11 +37,23 @@ def apply_optimizations(): optimization_method = None
+ can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp
+
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
optimization_method = 'xformers'
+ elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
+ print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
+ optimization_method = 'sdp-no-mem'
+ elif cmd_opts.opt_sdp_attention and can_use_sdp:
+ print("Applying scaled dot product cross attention optimization.")
+ ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
+ ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
+ optimization_method = 'sdp'
elif cmd_opts.opt_sub_quad_attention:
print("Applying sub-quadratic cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
|