From 2582a0fd3b3e91c5fba9e5e561cbdf5fee835063 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 18 May 2023 22:48:28 +0300 Subject: make it possible for scripts to add cross attention optimizations add UI selection for cross attention optimization --- modules/sd_hijack.py | 90 ++++++++++++++++++++++++++++------------------------ 1 file changed, 49 insertions(+), 41 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 14e7f799..39193be8 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -3,8 +3,9 @@ from torch.nn.functional import silu from types import MethodType import modules.textual_inversion.textual_inversion -from modules import devices, sd_hijack_optimizations, shared +from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors from modules.hypernetworks import hypernetwork +from modules.sd_hijack_optimizations import diffusionmodules_model_AttnBlock_forward from modules.shared import cmd_opts from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr @@ -28,57 +29,56 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] ldm.modules.attention.print = lambda *args: None ldm.modules.diffusionmodules.model.print = lambda *args: None +optimizers = [] +current_optimizer: sd_hijack_optimizations.SdOptimization = None + + +def list_optimizers(): + new_optimizers = script_callbacks.list_optimizers_callback() + + new_optimizers = [x for x in new_optimizers if x.is_available()] + + new_optimizers = sorted(new_optimizers, key=lambda x: x.priority(), reverse=True) + + optimizers.clear() + optimizers.extend(new_optimizers) + def apply_optimizations(): + global current_optimizer + undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th - optimization_method = None + if current_optimizer is not None: + current_optimizer.undo() + current_optimizer = None + + selection = shared.opts.cross_attention_optimization + if selection == "Automatic" and len(optimizers) > 0: + matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0]) + else: + matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None) - can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp - - if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): - print("Applying xformers cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward - optimization_method = 'xformers' - elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp: - print("Applying scaled dot product cross attention optimization (without memory efficient attention).") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward - optimization_method = 'sdp-no-mem' - elif cmd_opts.opt_sdp_attention and can_use_sdp: - 
print("Applying scaled dot product cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward - optimization_method = 'sdp' - elif cmd_opts.opt_sub_quad_attention: - print("Applying sub-quadratic cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward - optimization_method = 'sub-quadratic' - elif cmd_opts.opt_split_attention_v1: - print("Applying v1 cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - optimization_method = 'V1' - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()): - print("Applying cross attention optimization (InvokeAI).") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI - optimization_method = 'InvokeAI' - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): - print("Applying cross attention optimization (Doggettx).") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward - optimization_method = 'Doggettx' - - return optimization_method + if selection == "None": + matching_optimizer = None + elif matching_optimizer is None: + matching_optimizer = optimizers[0] + + if matching_optimizer is not None: + print(f"Applying optimization: {matching_optimizer.name}") + matching_optimizer.apply() + current_optimizer = matching_optimizer + return current_optimizer.name + else: + return '' def undo_optimizations(): - ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity + ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward @@ -169,7 +169,11 @@ class StableDiffusionModelHijack: if m.cond_stage_key == "edit": sd_hijack_unet.hijack_ddpm_edit() - self.optimization_method = apply_optimizations() + try: + self.optimization_method = apply_optimizations() + except Exception as e: + errors.display(e, "applying cross attention optimization") + undo_optimizations() self.clip = m.cond_stage_model @@ -223,6 +227,10 @@ class StableDiffusionModelHijack: return token_count, self.clip.get_target_prompt_token_count(token_count) + def redo_hijack(self, m): + self.undo_hijack(m) + self.hijack(m) + class EmbeddingsWithFixes(torch.nn.Module): def __init__(self, wrapped, embeddings): -- cgit v1.2.3 From 8a3d232839930376898634f65bd6c16f3a41e5b4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 19 May 2023 00:03:27 +0300 Subject: fix linter issues --- modules/sd_hijack.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 39193be8..75f1c540 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -5,7 
+5,6 @@ from types import MethodType import modules.textual_inversion.textual_inversion from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors from modules.hypernetworks import hypernetwork -from modules.sd_hijack_optimizations import diffusionmodules_model_AttnBlock_forward from modules.shared import cmd_opts from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr -- cgit v1.2.3 From 2140bd1c108dd17bbf8601b10da7865ed1ac1607 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 19 May 2023 10:05:07 +0300 Subject: make it actually work after suggestions --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 75f1c540..08d31080 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -37,7 +37,7 @@ def list_optimizers(): new_optimizers = [x for x in new_optimizers if x.is_available()] - new_optimizers = sorted(new_optimizers, key=lambda x: x.priority(), reverse=True) + new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True) optimizers.clear() optimizers.extend(new_optimizers) -- cgit v1.2.3
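For readers following the patches above, the contract an extension has to satisfy is small: apply_optimizations() only relies on an SdOptimization-like object exposing name, cmd_opt, priority, title(), is_available(), apply() and undo(), and list_optimizers() collects such objects through script_callbacks.list_optimizers_callback(). The sketch below illustrates that contract only; the on_list_optimizers registration helper and the callback receiving a result list to append to are assumptions not shown in these patches, and MyAttentionOptimization / my_cross_attention_forward are hypothetical names.

# Hypothetical extension-side sketch of the optimizer interface consumed by
# apply_optimizations() in the patches above; see the hedges in the lead-in.
import ldm.modules.attention

from modules import script_callbacks, sd_hijack_optimizations


def my_cross_attention_forward(self, x, context=None, mask=None):
    # A custom CrossAttention.forward implementation would go here.
    raise NotImplementedError


class MyAttentionOptimization(sd_hijack_optimizations.SdOptimization):
    name = "my-optimization"   # printed by "Applying optimization: ..." and returned to the hijack
    cmd_opt = None             # optional launch-flag name consulted when the UI option is "Automatic"
    priority = 10              # list_optimizers() sorts descending on this attribute

    def title(self):
        return self.name       # matched against shared.opts.cross_attention_optimization

    def is_available(self):
        return True            # entries returning False are filtered out in list_optimizers()

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = my_cross_attention_forward

    def undo(self):
        pass                   # the base class is assumed to restore the stock forward


def register(res):
    # Assumed callback shape: list_optimizers_callback() gathers entries from
    # registered callbacks, so appending to the passed-in list is one plausible form.
    res.append(MyAttentionOptimization())


script_callbacks.on_list_optimizers(register)  # assumed registration helper name

With something like this registered, the selection logic in apply_optimizations() treats built-in and script-provided optimizations uniformly: "Automatic" prefers an entry whose cmd_opt launch flag is enabled and otherwise falls back to the highest-priority one, while "None" disables cross attention optimization entirely.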