From 772db721a52da374d627b60994222051f26c27a7 Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Fri, 7 Oct 2022 23:02:07 +0900 Subject: fix glob path in hypernetwork.py --- modules/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/hypernetwork.py') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index c7b86682..7f062242 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -43,7 +43,7 @@ class Hypernetwork: def load_hypernetworks(path): res = {} - for filename in glob.iglob(path + '**/*.pt', recursive=True): + for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): try: hn = Hypernetwork(filename) res[hn.name] = hn -- cgit v1.2.3 From 122d42687b97ec4df4c2a8c335d2de385cd1f1a1 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sat, 8 Oct 2022 22:37:35 -0400 Subject: Fix VRAM Issue by only loading in hypernetwork when selected in settings --- modules/hypernetwork.py | 23 +++++++++++++++-------- modules/sd_hijack_optimizations.py | 6 +++--- modules/shared.py | 7 ++----- webui.py | 3 +++ 4 files changed, 23 insertions(+), 16 deletions(-) (limited to 'modules/hypernetwork.py') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index 7f062242..19f1c227 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -40,18 +40,25 @@ class Hypernetwork: self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) -def load_hypernetworks(path): +def list_hypernetworks(path): res = {} - for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): + name = os.path.splitext(os.path.basename(filename))[0] + res[name] = filename + return res + + +def load_hypernetwork(filename): + print(f"Loading hypernetwork {filename}") + path = shared.hypernetworks.get(filename, None) + if (path is not None): try: - hn = Hypernetwork(filename) - res[hn.name] = hn + shared.loaded_hypernetwork = Hypernetwork(path) except Exception: - print(f"Error loading hypernetwork {filename}", file=sys.stderr) + print(f"Error loading hypernetwork {path}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - - return res + else: + shared.loaded_hypernetwork = None def attention_CrossAttention_forward(self, x, context=None, mask=None): @@ -60,7 +67,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - hypernetwork = shared.selected_hypernetwork() + hypernetwork = shared.loaded_hypernetwork hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) if hypernetwork_layers is not None: diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index c4396bb9..634fb4b2 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -28,7 +28,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.selected_hypernetwork() + hypernetwork = shared.loaded_hypernetwork hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) if hypernetwork_layers is not None: @@ -68,7 +68,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.selected_hypernetwork() + hypernetwork = shared.loaded_hypernetwork hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) if hypernetwork_layers is not None: @@ -132,7 +132,7 @@ def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.selected_hypernetwork() + hypernetwork = shared.loaded_hypernetwork hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) if hypernetwork_layers is not None: k_in = self.to_k(hypernetwork_layers[0](context)) diff --git a/modules/shared.py b/modules/shared.py index b2c76a32..9dce6cb7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -79,11 +79,8 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram xformers_available = False config_filename = cmd_opts.ui_settings_file -hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks')) - - -def selected_hypernetwork(): - return hypernetworks.get(opts.sd_hypernetwork, None) +hypernetworks = hypernetwork.list_hypernetworks(os.path.join(models_path, 'hypernetworks')) +loaded_hypernetwork = None class State: diff --git a/webui.py b/webui.py index 18de8e16..270584f7 100644 --- a/webui.py +++ b/webui.py @@ -82,6 +82,9 @@ modules.scripts.load_scripts(os.path.join(script_path, "scripts")) shared.sd_model = modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) +loaded_hypernetwork = modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork) +shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) + def webui(): # make the program just exit at ctrl+c without waiting for anything -- cgit v1.2.3 From 542a3d3a4a00c1383fbdaf938ceefef87cf834bb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 14:33:22 +0300 Subject: fix btoken hypernetworks in XY plot --- modules/hypernetwork.py | 7 +++++-- scripts/xy_grid.py | 9 +++------ 2 files changed, 8 insertions(+), 8 deletions(-) (limited to 'modules/hypernetwork.py') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index 19f1c227..498bc9d8 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -49,15 +49,18 @@ def list_hypernetworks(path): def load_hypernetwork(filename): - print(f"Loading hypernetwork {filename}") path = shared.hypernetworks.get(filename, None) - if (path is not None): + if path is not None: + print(f"Loading hypernetwork {filename}") try: shared.loaded_hypernetwork = Hypernetwork(path) except Exception: print(f"Error loading hypernetwork {path}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) else: + if shared.loaded_hypernetwork is not None: + print(f"Unloading hypernetwork") + shared.loaded_hypernetwork = None diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index a8f53bef..fe949067 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,7 +10,7 @@ import numpy as np import modules.scripts as scripts import gradio as gr -from modules import images +from modules import images, hypernetwork from modules.processing import process_images, Processed, get_correct_sampler from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -80,8 +80,7 @@ def apply_checkpoint(p, x, xs): def apply_hypernetwork(p, x, xs): - hn = shared.hypernetworks.get(x, None) - opts.data["sd_hypernetwork"] = hn.name if hn is not None else 'None' + hypernetwork.load_hypernetwork(x) def format_value_add_label(p, opt, x): @@ -203,8 +202,6 @@ class Script(scripts.Script): p.batch_size = 1 - initial_hn = opts.sd_hypernetwork - def process_axis(opt, vals): if opt.label == 'Nothing': return [0] @@ -321,6 +318,6 @@ class Script(scripts.Script): # restore checkpoint in case it was changed by axes modules.sd_models.reload_model_weights(shared.sd_model) - opts.data["sd_hypernetwork"] = initial_hn + hypernetwork.load_hypernetwork(opts.sd_hypernetwork) return processed -- cgit v1.2.3 From 948533950c9db5069a874d925fadd50bac00fdb5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 11:09:51 +0300 Subject: replace duplicate code with a function --- modules/hypernetwork.py | 23 ++++++++++++-------- modules/sd_hijack_optimizations.py | 44 +++++++++++++------------------------- 2 files changed, 29 insertions(+), 38 deletions(-) (limited to 'modules/hypernetwork.py') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index 498bc9d8..7bbc443e 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -64,21 +64,26 @@ def load_hypernetwork(filename): shared.loaded_hypernetwork = None +def apply_hypernetwork(hypernetwork, context): + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is None: + return context, context + + context_k = hypernetwork_layers[0](context) + context_v = hypernetwork_layers[1](context) + return context_k, context_v + + def attention_CrossAttention_forward(self, x, context=None, mask=None): h = self.heads q = self.to_q(x) context = default(context, x) - hypernetwork = shared.loaded_hypernetwork - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is not None: - k = self.to_k(hypernetwork_layers[0](context)) - v = self.to_v(hypernetwork_layers[1](context)) - else: - k = self.to_k(context) - v = self.to_v(context) + context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context) + k = self.to_k(context_k) + v = self.to_v(context_v) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 18408e62..25cb67a4 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -8,7 +8,8 @@ from torch import einsum from ldm.util import default from einops import rearrange -from modules import shared +from modules import shared, hypernetwork + if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: try: @@ -26,16 +27,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.loaded_hypernetwork - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is not None: - k_in = self.to_k(hypernetwork_layers[0](context)) - v_in = self.to_v(hypernetwork_layers[1](context)) - else: - k_in = self.to_k(context) - v_in = self.to_v(context) - del context, x + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + del context, context_k, context_v, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in @@ -59,22 +54,16 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): return self.to_out(r2) -# taken from https://github.com/Doggettx/stable-diffusion +# taken from https://github.com/Doggettx/stable-diffusion and modified def split_cross_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.loaded_hypernetwork - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is not None: - k_in = self.to_k(hypernetwork_layers[0](context)) - v_in = self.to_v(hypernetwork_layers[1](context)) - else: - k_in = self.to_k(context) - v_in = self.to_v(context) + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) k_in *= self.scale @@ -130,14 +119,11 @@ def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.loaded_hypernetwork - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - if hypernetwork_layers is not None: - k_in = self.to_k(hypernetwork_layers[0](context)) - v_in = self.to_v(hypernetwork_layers[1](context)) - else: - k_in = self.to_k(context) - v_in = self.to_v(context) + + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) -- cgit v1.2.3