From 0981dea94832f34d638b1aa8964cfaeffd223b47 Mon Sep 17 00:00:00 2001 From: Pam Date: Fri, 10 Mar 2023 12:58:10 +0500 Subject: sdp refactoring --- modules/sd_hijack.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f62e9adb..e98ae51a 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -37,20 +37,21 @@ def apply_optimizations(): optimization_method = None + can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward optimization_method = 'xformers' - elif cmd_opts.opt_sdp_attention and (hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention"))): - if cmd_opts.opt_sdp_no_mem_attention: - print("Applying scaled dot product cross attention optimization (without memory efficient attention).") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward - optimization_method = 'sdp-no-mem' - else: - print("Applying scaled dot product cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward - optimization_method = 'sdp' + elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp: + print("Applying scaled dot product cross attention optimization (without memory efficient attention).") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward + optimization_method = 'sdp-no-mem' + elif cmd_opts.opt_sdp_attention and can_use_sdp: + print("Applying scaled dot product cross attention optimization.") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward + optimization_method = 'sdp' elif cmd_opts.opt_sub_quad_attention: print("Applying sub-quadratic cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward -- cgit v1.2.3