From 1c5604791da7e57f40880698666b6617a1754c65 Mon Sep 17 00:00:00 2001 From: DoTheSneedful Date: Mon, 3 Oct 2022 22:20:09 -0400 Subject: Add a prompt order option to XY plot script --- scripts/xy_grid.py | 40 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 146663b0..044c30e6 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -1,5 +1,6 @@ from collections import namedtuple from copy import copy +from itertools import permutations import random from PIL import Image @@ -28,6 +29,27 @@ def apply_prompt(p, x, xs): p.prompt = p.prompt.replace(xs[0], x) p.negative_prompt = p.negative_prompt.replace(xs[0], x) +def apply_order(p, x, xs): + token_order = [] + + # Initally grab the tokens from the prompt so they can be later be replaced in order of earliest seen in the prompt + for token in x: + token_order.append((p.prompt.find(token), token)) + + token_order.sort(key=lambda t: t[0]) + + search_from_pos = 0 + for idx, token in enumerate(x): + original_pos, old_token = token_order[idx] + + # Get position of the token again as it will likely change as tokens are being replaced + pos = p.prompt.find(old_token) + if original_pos >= 0: + # Avoid trying to replace what was just replaced by searching later in the prompt string + p.prompt = p.prompt[0:search_from_pos] + p.prompt[search_from_pos:].replace(old_token, token, 1) + + search_from_pos = pos + len(token) + samplers_dict = {} for i, sampler in enumerate(modules.sd_samplers.samplers): @@ -60,7 +82,8 @@ def format_value_add_label(p, opt, x): def format_value(p, opt, x): if type(x) == float: x = round(x, 8) - + if type(x) == type(list()): + x = str(x) return x def do_nothing(p, x, xs): @@ -89,6 +112,7 @@ axis_options = [ AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label), AxisOption("Eta", float, apply_field("eta"), format_value_add_label), + AxisOption("Prompt order", type(list()), apply_order, format_value), AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones ] @@ -159,7 +183,11 @@ class Script(scripts.Script): if opt.label == 'Nothing': return [0] - valslist = [x.strip() for x in vals.split(",")] + if opt.type == type(list()): + valslist = [x for x in vals] + else: + valslist = [x.strip() for x in vals.split(",")] + if opt.type == int: valslist_ext = [] @@ -212,9 +240,17 @@ class Script(scripts.Script): return valslist x_opt = axis_options[x_type] + + if x_opt.label == "Prompt order": + x_values = list(permutations([x.strip() for x in x_values.split(",")])) + xs = process_axis(x_opt, x_values) y_opt = axis_options[y_type] + + if y_opt.label == "Prompt order": + y_values = list(permutations([y.strip() for y in y_values.split(",")])) + ys = process_axis(y_opt, y_values) def fix_axis_seeds(axis_opt, axis_list): -- cgit v1.2.3 From 1a6d40db35656083d5bf9d3a3430b45fda4e85eb Mon Sep 17 00:00:00 2001 From: DoTheSneedful Date: Tue, 4 Oct 2022 00:18:15 -0400 Subject: Fix token ordering in prompt order XY plot --- scripts/xy_grid.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 044c30e6..5bcd3921 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -32,24 +32,21 @@ def apply_prompt(p, x, xs): def apply_order(p, x, xs): token_order = [] - # Initally grab the tokens from the prompt so they can be later be replaced in order of earliest seen in the prompt + # Initally grab the tokens from the prompt so they can be be replaced in order of earliest seen for token in x: token_order.append((p.prompt.find(token), token)) token_order.sort(key=lambda t: t[0]) search_from_pos = 0 - for idx, token in enumerate(x): - original_pos, old_token = token_order[idx] - + for idx, (original_pos, old_token) in enumerate(token_order): # Get position of the token again as it will likely change as tokens are being replaced - pos = p.prompt.find(old_token) + pos = search_from_pos + p.prompt[search_from_pos:].find(old_token) if original_pos >= 0: # Avoid trying to replace what was just replaced by searching later in the prompt string - p.prompt = p.prompt[0:search_from_pos] + p.prompt[search_from_pos:].replace(old_token, token, 1) - - search_from_pos = pos + len(token) + p.prompt = p.prompt[0:search_from_pos] + p.prompt[search_from_pos:].replace(old_token, x[idx], 1) + search_from_pos = pos + len(x[idx]) samplers_dict = {} for i, sampler in enumerate(modules.sd_samplers.samplers): -- cgit v1.2.3 From 56371153b545e3a43c3a5f206264019af361f3af Mon Sep 17 00:00:00 2001 From: DoTheSneedful Date: Tue, 4 Oct 2022 01:07:36 -0400 Subject: XY plot prompt order simplify logic --- scripts/xy_grid.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 5bcd3921..7def47f5 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -38,15 +38,21 @@ def apply_order(p, x, xs): token_order.sort(key=lambda t: t[0]) - search_from_pos = 0 - for idx, (original_pos, old_token) in enumerate(token_order): - # Get position of the token again as it will likely change as tokens are being replaced - pos = search_from_pos + p.prompt[search_from_pos:].find(old_token) - if original_pos >= 0: - # Avoid trying to replace what was just replaced by searching later in the prompt string - p.prompt = p.prompt[0:search_from_pos] + p.prompt[search_from_pos:].replace(old_token, x[idx], 1) - - search_from_pos = pos + len(x[idx]) + prompt_parts = [] + + # Split the prompt up, taking out the tokens + for _, token in token_order: + n = p.prompt.find(token) + prompt_parts.append(p.prompt[0:n]) + p.prompt = p.prompt[n + len(token):] + + # Rebuild the prompt with the tokens in the order we want + prompt_tmp = "" + for idx, part in enumerate(prompt_parts): + prompt_tmp += part + prompt_tmp += x[idx] + p.prompt = prompt_tmp + p.prompt + samplers_dict = {} for i, sampler in enumerate(modules.sd_samplers.samplers): -- cgit v1.2.3 From 556c36b9607e3f4eacdddc85f8e7a78b29476ea7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 4 Oct 2022 09:18:00 +0300 Subject: add hint, refactor code for #1607 --- javascript/hints.js | 1 + scripts/xy_grid.py | 35 ++++++++++++++++++----------------- 2 files changed, 19 insertions(+), 17 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/javascript/hints.js b/javascript/hints.js index e72e9338..8adcd983 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -47,6 +47,7 @@ titles = { "Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work", "Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others", + "Prompt order": "Separate a list of words with commas, and the script will make a variation of prompt with those words for their every possible order", "Tiling": "Produce an image that can be tiled.", "Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.", diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 7def47f5..1237e754 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -29,10 +29,11 @@ def apply_prompt(p, x, xs): p.prompt = p.prompt.replace(xs[0], x) p.negative_prompt = p.negative_prompt.replace(xs[0], x) + def apply_order(p, x, xs): token_order = [] - # Initally grab the tokens from the prompt so they can be be replaced in order of earliest seen + # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen for token in x: token_order.append((p.prompt.find(token), token)) @@ -85,17 +86,26 @@ def format_value_add_label(p, opt, x): def format_value(p, opt, x): if type(x) == float: x = round(x, 8) - if type(x) == type(list()): - x = str(x) return x + +def format_value_join_list(p, opt, x): + return ", ".join(x) + + def do_nothing(p, x, xs): pass + def format_nothing(p, opt, x): return "" +def str_permutations(x): + """dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" + return x + + AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"]) AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"]) @@ -108,6 +118,7 @@ axis_options = [ AxisOption("Steps", int, apply_field("steps"), format_value_add_label), AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label), AxisOption("Prompt S/R", str, apply_prompt, format_value), + AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list), AxisOption("Sampler", str, apply_sampler, format_value), AxisOption("Checkpoint name", str, apply_checkpoint, format_value), AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label), @@ -115,7 +126,6 @@ axis_options = [ AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label), AxisOption("Eta", float, apply_field("eta"), format_value_add_label), - AxisOption("Prompt order", type(list()), apply_order, format_value), AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones ] @@ -158,6 +168,7 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") + class Script(scripts.Script): def title(self): return "X/Y plot" @@ -186,11 +197,7 @@ class Script(scripts.Script): if opt.label == 'Nothing': return [0] - if opt.type == type(list()): - valslist = [x for x in vals] - else: - valslist = [x.strip() for x in vals.split(",")] - + valslist = [x.strip() for x in vals.split(",")] if opt.type == int: valslist_ext = [] @@ -237,23 +244,17 @@ class Script(scripts.Script): valslist_ext.append(val) valslist = valslist_ext + elif opt.type == str_permutations: + valslist = list(permutations(valslist)) valslist = [opt.type(x) for x in valslist] return valslist x_opt = axis_options[x_type] - - if x_opt.label == "Prompt order": - x_values = list(permutations([x.strip() for x in x_values.split(",")])) - xs = process_axis(x_opt, x_values) y_opt = axis_options[y_type] - - if y_opt.label == "Prompt order": - y_values = list(permutations([y.strip() for y in y_values.split(",")])) - ys = process_axis(y_opt, y_values) def fix_axis_seeds(axis_opt, axis_list): -- cgit v1.2.3 From 5d0e6ab8567bda2ee8f5ed31f332ca07c1b84b98 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 6 Oct 2022 04:04:50 +0100 Subject: Allow escaping of commas in xy_grid --- scripts/xy_grid.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 1237e754..210829a7 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -168,6 +168,7 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") +re_non_escaped_comma = re.compile(r"(? Date: Thu, 6 Oct 2022 11:55:21 +0100 Subject: use csv.reader --- scripts/xy_grid.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 210829a7..1a625898 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -1,8 +1,9 @@ from collections import namedtuple from copy import copy -from itertools import permutations +from itertools import permutations, chain import random - +import csv +from io import StringIO from PIL import Image import numpy as np @@ -168,8 +169,6 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") -re_non_escaped_comma = re.compile(r"(? Date: Thu, 6 Oct 2022 12:32:17 +0100 Subject: strip() split comma delimited lines --- scripts/xy_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 1a625898..ec27e58b 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -197,7 +197,7 @@ class Script(scripts.Script): if opt.label == 'Nothing': return [0] - valslist = list(chain.from_iterable(csv.reader(StringIO(s)))) + valslist = list(map(str.strip,chain.from_iterable(csv.reader(StringIO(s))))) if opt.type == int: valslist_ext = [] -- cgit v1.2.3 From 82eb8ea452b1e63535c58d15ec6db2ad2342faa8 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 6 Oct 2022 15:22:51 +0100 Subject: Update xy_grid.py split vals not 's' from tests --- scripts/xy_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index ec27e58b..210c7b6e 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -197,7 +197,7 @@ class Script(scripts.Script): if opt.label == 'Nothing': return [0] - valslist = list(map(str.strip,chain.from_iterable(csv.reader(StringIO(s))))) + valslist = list(map(str.strip,chain.from_iterable(csv.reader(StringIO(vals))))) if opt.type == int: valslist_ext = [] -- cgit v1.2.3 From 1069ec49a35d04c1e85c92534e92a2d6aa59cb75 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 6 Oct 2022 20:16:21 +0300 Subject: revert back to using list comprehension rather than list and map --- scripts/xy_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 210c7b6e..6344e612 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -197,7 +197,7 @@ class Script(scripts.Script): if opt.label == 'Nothing': return [0] - valslist = list(map(str.strip,chain.from_iterable(csv.reader(StringIO(vals))))) + valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))] if opt.type == int: valslist_ext = [] -- cgit v1.2.3 From bad7cb29cecac51c5c0f39afec332b007ed73133 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 7 Oct 2022 10:17:52 +0300 Subject: added support for hypernetworks (???) --- modules/hypernetwork.py | 55 ++++++++++++++++++++++++++++++++++++++ modules/sd_hijack_optimizations.py | 17 ++++++++++-- modules/shared.py | 9 ++++++- scripts/xy_grid.py | 10 +++++++ 4 files changed, 88 insertions(+), 3 deletions(-) create mode 100644 modules/hypernetwork.py (limited to 'scripts/xy_grid.py') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py new file mode 100644 index 00000000..9ed1eed9 --- /dev/null +++ b/modules/hypernetwork.py @@ -0,0 +1,55 @@ +import glob +import os +import torch +from modules import devices + + +class HypernetworkModule(torch.nn.Module): + def __init__(self, dim, state_dict): + super().__init__() + + self.linear1 = torch.nn.Linear(dim, dim * 2) + self.linear2 = torch.nn.Linear(dim * 2, dim) + + self.load_state_dict(state_dict, strict=True) + self.to(devices.device) + + def forward(self, x): + return x + (self.linear2(self.linear1(x))) + + +class Hypernetwork: + filename = None + name = None + + def __init__(self, filename): + self.filename = filename + self.name = os.path.splitext(os.path.basename(filename))[0] + self.layers = {} + + state_dict = torch.load(filename, map_location='cpu') + for size, sd in state_dict.items(): + self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) + + +def load_hypernetworks(path): + res = {} + + for filename in glob.iglob(path + '**/*.pt', recursive=True): + hn = Hypernetwork(filename) + res[hn.name] = hn + + return res + +def apply(self, x, context=None, mask=None, original=None): + + + if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork: + if context.shape[1] == 77 and CrossAttention.noise_cond: + context = context + (torch.randn_like(context) * 0.1) + h_k, h_v = CrossAttention.hypernetwork[context.shape[2]] + k = self.to_k(h_k(context)) + v = self.to_v(h_v(context)) + else: + k = self.to_k(context) + v = self.to_v(context) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index ea4cfdfc..d9cca485 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -5,6 +5,8 @@ from torch import einsum from ldm.util import default from einops import rearrange +from modules import shared + # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): @@ -42,8 +44,19 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - k_in = self.to_k(context) * self.scale - v_in = self.to_v(context) + + hypernetwork = shared.selected_hypernetwork() + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is not None: + k_in = self.to_k(hypernetwork_layers[0](context)) + v_in = self.to_v(hypernetwork_layers[1](context)) + else: + k_in = self.to_k(context) + v_in = self.to_v(context) + + k_in *= self.scale + del context, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) diff --git a/modules/shared.py b/modules/shared.py index 25bb6e6c..879d8424 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers +from modules import sd_samplers, hypernetwork from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') @@ -76,6 +76,12 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram config_filename = cmd_opts.ui_settings_file +hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks')) + + +def selected_hypernetwork(): + return hypernetworks.get(opts.sd_hypernetwork, None) + class State: interrupted = False @@ -206,6 +212,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), + "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 6344e612..c0c364df 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -77,6 +77,11 @@ def apply_checkpoint(p, x, xs): modules.sd_models.reload_model_weights(shared.sd_model, info) +def apply_hypernetwork(p, x, xs): + hn = shared.hypernetworks.get(x, None) + opts.data["sd_hypernetwork"] = hn.name if hn is not None else 'None' + + def format_value_add_label(p, opt, x): if type(x) == float: x = round(x, 8) @@ -122,6 +127,7 @@ axis_options = [ AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list), AxisOption("Sampler", str, apply_sampler, format_value), AxisOption("Checkpoint name", str, apply_checkpoint, format_value), + AxisOption("Hypernetwork", str, apply_hypernetwork, format_value), AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label), AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label), AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), @@ -193,6 +199,8 @@ class Script(scripts.Script): modules.processing.fix_seed(p) p.batch_size = 1 + initial_hn = opts.sd_hypernetwork + def process_axis(opt, vals): if opt.label == 'Nothing': return [0] @@ -300,4 +308,6 @@ class Script(scripts.Script): # restore checkpoint in case it was changed by axes modules.sd_models.reload_model_weights(shared.sd_model) + opts.data["sd_hypernetwork"] = initial_hn + return processed -- cgit v1.2.3 From 03e570886f430f39020e504aba057a95f2e62484 Mon Sep 17 00:00:00 2001 From: frostydad <64224601+Cyberes@users.noreply.github.com> Date: Sat, 8 Oct 2022 18:13:13 -0600 Subject: Fix incorrect sampler name in output --- modules/processing.py | 9 ++++++++- scripts/xy_grid.py | 16 +++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/processing.py b/modules/processing.py index 4fea6d56..6b8664a0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1,3 +1,4 @@ + import json import math import os @@ -46,6 +47,12 @@ def apply_color_correction(correction, image): return image +def get_correct_sampler(p): + if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img): + return sd_samplers.samplers + elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img): + return sd_samplers.samplers_for_img2img + class StableDiffusionProcessing: def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None): self.sd_model = sd_model @@ -272,7 +279,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params = { "Steps": p.steps, - "Sampler": sd_samplers.samplers[p.sampler_index].name, + "Sampler": get_correct_sampler(p)[p.sampler_index].name, "CFG scale": p.cfg_scale, "Seed": all_seeds[index], "Face restoration": (opts.face_restoration_model if p.restore_faces else None), diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index c0c364df..26ae2199 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -11,7 +11,7 @@ import modules.scripts as scripts import gradio as gr from modules import images -from modules.processing import process_images, Processed +from modules.processing import process_images, Processed, get_correct_sampler from modules.shared import opts, cmd_opts, state import modules.shared as shared import modules.sd_samplers @@ -56,15 +56,17 @@ def apply_order(p, x, xs): p.prompt = prompt_tmp + p.prompt -samplers_dict = {} -for i, sampler in enumerate(modules.sd_samplers.samplers): - samplers_dict[sampler.name.lower()] = i - for alias in sampler.aliases: - samplers_dict[alias.lower()] = i +def build_samplers_dict(p): + samplers_dict = {} + for i, sampler in enumerate(get_correct_sampler(p)): + samplers_dict[sampler.name.lower()] = i + for alias in sampler.aliases: + samplers_dict[alias.lower()] = i + return samplers_dict def apply_sampler(p, x, xs): - sampler_index = samplers_dict.get(x.lower(), None) + sampler_index = build_samplers_dict(p).get(x.lower(), None) if sampler_index is None: raise RuntimeError(f"Unknown sampler: {x}") -- cgit v1.2.3 From d74c38108f95e44d83a1706ee5ab218124972868 Mon Sep 17 00:00:00 2001 From: Jesse Williams <33797815+xram64@users.noreply.github.com> Date: Sat, 8 Oct 2022 01:30:49 -0400 Subject: Confirm that options are valid before starting When using the 'Sampler' or 'Checkpoint' options, if one of the entered names has a typo, an error will only be thrown once the `draw_xy_grid` loop reaches that name. This can waste a lot of time for large grids with a typo near the end of a list, since the script needs to start over and re-generate any earlier images to finish making the grid. Also fixing typo in variable name in `draw_xy_grid`. --- scripts/xy_grid.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 26ae2199..07040886 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -145,7 +145,7 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): ver_texts = [[images.GridAnnotation(y)] for y in y_labels] hor_texts = [[images.GridAnnotation(x)] for x in x_labels] - first_pocessed = None + first_processed = None state.job_count = len(xs) * len(ys) * p.n_iter @@ -154,8 +154,8 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" processed = cell(x, y) - if first_pocessed is None: - first_pocessed = processed + if first_processed is None: + first_processed = processed try: res.append(processed.images[0]) @@ -166,9 +166,9 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): if draw_legend: grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) - first_pocessed.images = [grid] + first_processed.images = [grid] - return first_pocessed + return first_processed re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") @@ -216,7 +216,6 @@ class Script(scripts.Script): m = re_range.fullmatch(val) mc = re_range_count.fullmatch(val) if m is not None: - start = int(m.group(1)) end = int(m.group(2))+1 step = int(m.group(3)) if m.group(3) is not None else 1 @@ -258,6 +257,16 @@ class Script(scripts.Script): valslist = list(permutations(valslist)) valslist = [opt.type(x) for x in valslist] + + # Confirm options are valid before starting + if opt.label == "Sampler": + for sampler_val in valslist: + if sampler_val.lower() not in samplers_dict.keys(): + raise RuntimeError(f"Unknown sampler: {sampler_val}") + elif opt.label == "Checkpoint name": + for ckpt_val in valslist: + if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None: + raise RuntimeError(f"Checkpoint for {ckpt_val} not found") return valslist -- cgit v1.2.3 From a65a45272e8f26ee3bc52a5300b396266508a9a5 Mon Sep 17 00:00:00 2001 From: Brendan Byrd Date: Thu, 6 Oct 2022 19:31:36 -0400 Subject: Don't change the seed initially if "Keep -1 for seeds" is checked Fixes #1049 --- scripts/xy_grid.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 07040886..a8f53bef 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -198,7 +198,9 @@ class Script(scripts.Script): return [x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds] def run(self, p, x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds): - modules.processing.fix_seed(p) + if not no_fixed_seeds: + modules.processing.fix_seed(p) + p.batch_size = 1 initial_hn = opts.sd_hypernetwork -- cgit v1.2.3 From 542a3d3a4a00c1383fbdaf938ceefef87cf834bb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 14:33:22 +0300 Subject: fix btoken hypernetworks in XY plot --- modules/hypernetwork.py | 7 +++++-- scripts/xy_grid.py | 9 +++------ 2 files changed, 8 insertions(+), 8 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index 19f1c227..498bc9d8 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -49,15 +49,18 @@ def list_hypernetworks(path): def load_hypernetwork(filename): - print(f"Loading hypernetwork {filename}") path = shared.hypernetworks.get(filename, None) - if (path is not None): + if path is not None: + print(f"Loading hypernetwork {filename}") try: shared.loaded_hypernetwork = Hypernetwork(path) except Exception: print(f"Error loading hypernetwork {path}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) else: + if shared.loaded_hypernetwork is not None: + print(f"Unloading hypernetwork") + shared.loaded_hypernetwork = None diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index a8f53bef..fe949067 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,7 +10,7 @@ import numpy as np import modules.scripts as scripts import gradio as gr -from modules import images +from modules import images, hypernetwork from modules.processing import process_images, Processed, get_correct_sampler from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -80,8 +80,7 @@ def apply_checkpoint(p, x, xs): def apply_hypernetwork(p, x, xs): - hn = shared.hypernetworks.get(x, None) - opts.data["sd_hypernetwork"] = hn.name if hn is not None else 'None' + hypernetwork.load_hypernetwork(x) def format_value_add_label(p, opt, x): @@ -203,8 +202,6 @@ class Script(scripts.Script): p.batch_size = 1 - initial_hn = opts.sd_hypernetwork - def process_axis(opt, vals): if opt.label == 'Nothing': return [0] @@ -321,6 +318,6 @@ class Script(scripts.Script): # restore checkpoint in case it was changed by axes modules.sd_models.reload_model_weights(shared.sd_model) - opts.data["sd_hypernetwork"] = initial_hn + hypernetwork.load_hypernetwork(opts.sd_hypernetwork) return processed -- cgit v1.2.3 From 2c52f4da7ff80a3ec277105f4db6146c6379898a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 15:01:42 +0300 Subject: fix broken samplers in XY plot --- scripts/xy_grid.py | 1 + 1 file changed, 1 insertion(+) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index fe949067..c89ca1a9 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -259,6 +259,7 @@ class Script(scripts.Script): # Confirm options are valid before starting if opt.label == "Sampler": + samplers_dict = build_samplers_dict(p) for sampler_val in valslist: if sampler_val.lower() not in samplers_dict.keys(): raise RuntimeError(f"Unknown sampler: {sampler_val}") -- cgit v1.2.3 From 45bf9a6264b3507473e02cc3f9aa36559f24aca2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 18:58:55 +0300 Subject: added clip skip to XY plot --- scripts/xy_grid.py | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index c89ca1a9..7b0d9083 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -83,6 +83,10 @@ def apply_hypernetwork(p, x, xs): hypernetwork.load_hypernetwork(x) +def apply_clip_skip(p, x, xs): + opts.data["CLIP_ignore_last_layers"] = x + + def format_value_add_label(p, opt, x): if type(x) == float: x = round(x, 8) @@ -134,6 +138,7 @@ axis_options = [ AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label), AxisOption("Eta", float, apply_field("eta"), format_value_add_label), + AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label), AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones ] @@ -201,6 +206,7 @@ class Script(scripts.Script): modules.processing.fix_seed(p) p.batch_size = 1 + CLIP_ignore_last_layers = opts.CLIP_ignore_last_layers def process_axis(opt, vals): if opt.label == 'Nothing': @@ -321,4 +327,6 @@ class Script(scripts.Script): hypernetwork.load_hypernetwork(opts.sd_hypernetwork) + opts.data["CLIP_ignore_last_layers"] = CLIP_ignore_last_layers + return processed -- cgit v1.2.3 From 84ddd44113b36062e8ba6cb2e5db0fce4f48efb8 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sun, 9 Oct 2022 14:57:17 -0400 Subject: Clip skip variable name change breaks x/y plot script. This fixes that --- scripts/xy_grid.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 7b0d9083..771eb8e4 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -84,7 +84,7 @@ def apply_hypernetwork(p, x, xs): def apply_clip_skip(p, x, xs): - opts.data["CLIP_ignore_last_layers"] = x + opts.data["CLIP_stop_at_last_layers"] = x def format_value_add_label(p, opt, x): @@ -206,7 +206,7 @@ class Script(scripts.Script): modules.processing.fix_seed(p) p.batch_size = 1 - CLIP_ignore_last_layers = opts.CLIP_ignore_last_layers + CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers def process_axis(opt, vals): if opt.label == 'Nothing': @@ -327,6 +327,6 @@ class Script(scripts.Script): hypernetwork.load_hypernetwork(opts.sd_hypernetwork) - opts.data["CLIP_ignore_last_layers"] = CLIP_ignore_last_layers + opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers return processed -- cgit v1.2.3 From 5da1ba0e91a81804dc911d34c9a2e6956a23199c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 21:24:11 +0300 Subject: remove batch size restriction from X/Y plot --- scripts/xy_grid.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 771eb8e4..42e1489c 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -205,7 +205,10 @@ class Script(scripts.Script): if not no_fixed_seeds: modules.processing.fix_seed(p) - p.batch_size = 1 + if not opts.return_grid: + p.batch_size = 1 + + CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers def process_axis(opt, vals): -- cgit v1.2.3 From 255be75d30f41e089e499ec1c8462d6bf64dec24 Mon Sep 17 00:00:00 2001 From: aperullo <18688190+aperullo@users.noreply.github.com> Date: Tue, 11 Oct 2022 06:16:57 -0400 Subject: Error if prompt missing SR token to prevent mis-gens (#2209) --- scripts/xy_grid.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 42e1489c..10a82dc9 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -27,9 +27,16 @@ def apply_field(field): def apply_prompt(p, x, xs): + + orig_prompt = p.prompt + orig_negative_prompt = p.negative_prompt + p.prompt = p.prompt.replace(xs[0], x) p.negative_prompt = p.negative_prompt.replace(xs[0], x) + if p.prompt == orig_prompt and p.negative_prompt == orig_negative_prompt: + raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt. Did you forget to add the token?") + def apply_order(p, x, xs): token_order = [] -- cgit v1.2.3 From a8490e4019c359ff24824e004059744d7164361b Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 11:42:41 +0100 Subject: revert sr warning --- scripts/xy_grid.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 10a82dc9..99b3c4f6 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -35,7 +35,8 @@ def apply_prompt(p, x, xs): p.negative_prompt = p.negative_prompt.replace(xs[0], x) if p.prompt == orig_prompt and p.negative_prompt == orig_negative_prompt: - raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt. Did you forget to add the token?") + pass + #raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt. Did you forget to add the token?") def apply_order(p, x, xs): -- cgit v1.2.3 From 1a0a6a84c3149e236211d547471f5416cd1129f3 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 11:59:56 +0100 Subject: add incorrect start word guard to xy_grid (#2259) --- scripts/xy_grid.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 99b3c4f6..9d4d6187 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -27,17 +27,12 @@ def apply_field(field): def apply_prompt(p, x, xs): - - orig_prompt = p.prompt - orig_negative_prompt = p.negative_prompt + if xs[0] not in p.prompt and xs[0] not in p.negative_prompt: + raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.") p.prompt = p.prompt.replace(xs[0], x) p.negative_prompt = p.negative_prompt.replace(xs[0], x) - if p.prompt == orig_prompt and p.negative_prompt == orig_negative_prompt: - pass - #raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt. Did you forget to add the token?") - def apply_order(p, x, xs): token_order = [] -- cgit v1.2.3 From 530103b586109c11fd068eb70ef09503ec6a4caf Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 14:53:02 +0300 Subject: fixes related to merge --- modules/hypernetwork.py | 103 ------------------------- modules/hypernetwork/hypernetwork.py | 74 +++++++++++------- modules/hypernetwork/ui.py | 10 +-- modules/sd_hijack_optimizations.py | 3 +- modules/shared.py | 13 +++- modules/textual_inversion/textual_inversion.py | 12 +-- modules/ui.py | 5 +- scripts/xy_grid.py | 3 +- webui.py | 15 +--- 9 files changed, 78 insertions(+), 160 deletions(-) delete mode 100644 modules/hypernetwork.py (limited to 'scripts/xy_grid.py') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py deleted file mode 100644 index 7bbc443e..00000000 --- a/modules/hypernetwork.py +++ /dev/null @@ -1,103 +0,0 @@ -import glob -import os -import sys -import traceback - -import torch - -from ldm.util import default -from modules import devices, shared -import torch -from torch import einsum -from einops import rearrange, repeat - - -class HypernetworkModule(torch.nn.Module): - def __init__(self, dim, state_dict): - super().__init__() - - self.linear1 = torch.nn.Linear(dim, dim * 2) - self.linear2 = torch.nn.Linear(dim * 2, dim) - - self.load_state_dict(state_dict, strict=True) - self.to(devices.device) - - def forward(self, x): - return x + (self.linear2(self.linear1(x))) - - -class Hypernetwork: - filename = None - name = None - - def __init__(self, filename): - self.filename = filename - self.name = os.path.splitext(os.path.basename(filename))[0] - self.layers = {} - - state_dict = torch.load(filename, map_location='cpu') - for size, sd in state_dict.items(): - self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) - - -def list_hypernetworks(path): - res = {} - for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): - name = os.path.splitext(os.path.basename(filename))[0] - res[name] = filename - return res - - -def load_hypernetwork(filename): - path = shared.hypernetworks.get(filename, None) - if path is not None: - print(f"Loading hypernetwork {filename}") - try: - shared.loaded_hypernetwork = Hypernetwork(path) - except Exception: - print(f"Error loading hypernetwork {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - if shared.loaded_hypernetwork is not None: - print(f"Unloading hypernetwork") - - shared.loaded_hypernetwork = None - - -def apply_hypernetwork(hypernetwork, context): - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is None: - return context, context - - context_k = hypernetwork_layers[0](context) - context_v = hypernetwork_layers[1](context) - return context_k, context_v - - -def attention_CrossAttention_forward(self, x, context=None, mask=None): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - - context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context) - k = self.to_k(context_k) - v = self.to_v(context_v) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - - if mask is not None: - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) - - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) diff --git a/modules/hypernetwork/hypernetwork.py b/modules/hypernetwork/hypernetwork.py index a3d6a47e..aa701bda 100644 --- a/modules/hypernetwork/hypernetwork.py +++ b/modules/hypernetwork/hypernetwork.py @@ -26,10 +26,11 @@ class HypernetworkModule(torch.nn.Module): if state_dict is not None: self.load_state_dict(state_dict, strict=True) else: - self.linear1.weight.data.fill_(0.0001) - self.linear1.bias.data.fill_(0.0001) - self.linear2.weight.data.fill_(0.0001) - self.linear2.bias.data.fill_(0.0001) + + self.linear1.weight.data.normal_(mean=0.0, std=0.01) + self.linear1.bias.data.zero_() + self.linear2.weight.data.normal_(mean=0.0, std=0.01) + self.linear2.bias.data.zero_() self.to(devices.device) @@ -92,41 +93,54 @@ class Hypernetwork: self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) -def load_hypernetworks(path): +def list_hypernetworks(path): res = {} + for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): + name = os.path.splitext(os.path.basename(filename))[0] + res[name] = filename + return res - for filename in glob.iglob(path + '**/*.pt', recursive=True): + +def load_hypernetwork(filename): + path = shared.hypernetworks.get(filename, None) + if path is not None: + print(f"Loading hypernetwork {filename}") try: - hn = Hypernetwork() - hn.load(filename) - res[hn.name] = hn + shared.loaded_hypernetwork = Hypernetwork() + shared.loaded_hypernetwork.load(path) + except Exception: - print(f"Error loading hypernetwork {filename}", file=sys.stderr) + print(f"Error loading hypernetwork {path}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + else: + if shared.loaded_hypernetwork is not None: + print(f"Unloading hypernetwork") - return res + shared.loaded_hypernetwork = None -def attention_CrossAttention_forward(self, x, context=None, mask=None): - h = self.heads +def apply_hypernetwork(hypernetwork, context, layer=None): + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - q = self.to_q(x) - context = default(context, x) + if hypernetwork_layers is None: + return context, context - hypernetwork_layers = (shared.hypernetwork.layers if shared.hypernetwork is not None else {}).get(context.shape[2], None) + if layer is not None: + layer.hyper_k = hypernetwork_layers[0] + layer.hyper_v = hypernetwork_layers[1] - if hypernetwork_layers is not None: - hypernetwork_k, hypernetwork_v = hypernetwork_layers + context_k = hypernetwork_layers[0](context) + context_v = hypernetwork_layers[1](context) + return context_k, context_v - self.hypernetwork_k = hypernetwork_k - self.hypernetwork_v = hypernetwork_v - context_k = hypernetwork_k(context) - context_v = hypernetwork_v(context) - else: - context_k = context - context_v = context +def attention_CrossAttention_forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self) k = self.to_k(context_k) v = self.to_v(context_v) @@ -151,7 +165,9 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): assert hypernetwork_name, 'embedding not selected' - shared.hypernetwork = shared.hypernetworks[hypernetwork_name] + path = shared.hypernetworks.get(hypernetwork_name, None) + shared.loaded_hypernetwork = Hypernetwork() + shared.loaded_hypernetwork.load(path) shared.state.textinfo = "Initializing hypernetwork training..." shared.state.job_count = steps @@ -176,9 +192,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) - hypernetwork = shared.hypernetworks[hypernetwork_name] + hypernetwork = shared.loaded_hypernetwork weights = hypernetwork.weights() for weight in weights: weight.requires_grad = True @@ -194,7 +210,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, if ititial_step > steps: return hypernetwork, filename - pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, (x, text) in pbar: hypernetwork.step = i + ititial_step diff --git a/modules/hypernetwork/ui.py b/modules/hypernetwork/ui.py index 525f978c..f6d1d0a3 100644 --- a/modules/hypernetwork/ui.py +++ b/modules/hypernetwork/ui.py @@ -6,24 +6,24 @@ import gradio as gr import modules.textual_inversion.textual_inversion import modules.textual_inversion.preprocess from modules import sd_hijack, shared +from modules.hypernetwork import hypernetwork def create_hypernetwork(name): fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") assert not os.path.exists(fn), f"file {fn} already exists" - hypernetwork = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) - hypernetwork.save(fn) + hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) + hypernet.save(fn) shared.reload_hypernetworks() - shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None) return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" def train_hypernetwork(*args): - initial_hypernetwork = shared.hypernetwork + initial_hypernetwork = shared.loaded_hypernetwork try: sd_hijack.undo_optimizations() @@ -38,6 +38,6 @@ Hypernetwork saved to {html.escape(filename)} except Exception: raise finally: - shared.hypernetwork = initial_hypernetwork + shared.loaded_hypernetwork = initial_hypernetwork sd_hijack.apply_optimizations() diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 25cb67a4..27e571fc 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -8,7 +8,8 @@ from torch import einsum from ldm.util import default from einops import rearrange -from modules import shared, hypernetwork +from modules import shared +from modules.hypernetwork import hypernetwork if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: diff --git a/modules/shared.py b/modules/shared.py index 14b40d70..8753015e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,8 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, hypernetwork +from modules import sd_samplers +from modules.hypernetwork import hypernetwork from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') @@ -29,6 +30,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") @@ -82,10 +84,17 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram xformers_available = False config_filename = cmd_opts.ui_settings_file -hypernetworks = hypernetwork.list_hypernetworks(os.path.join(models_path, 'hypernetworks')) +hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None +def reload_hypernetworks(): + global hypernetworks + + hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) + hypernetwork.load_hypernetwork(opts.sd_hypernetwork) + + class State: skipped = False interrupted = False diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5965c5a0..d6977950 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -156,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -238,12 +238,14 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') + preview_text = text if preview_image_prompt == "" else preview_image_prompt + p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, - prompt=text, + prompt=preview_text, steps=20, - height=training_height, - width=training_width, + height=training_height, + width=training_width, do_not_save_grid=True, do_not_save_samples=True, ) @@ -254,7 +256,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.current_image = image image.save(last_saved_image) - last_saved_image += f", prompt: {text}" + last_saved_image += f", prompt: {preview_text}" shared.state.job_no = embedding.step diff --git a/modules/ui.py b/modules/ui.py index 10b1ee3a..df653059 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1023,7 +1023,7 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="") with gr.Column(): - create_embedding = gr.Button(value="Create", variant='primary') + create_embedding = gr.Button(value="Create embedding", variant='primary') with gr.Group(): gr.HTML(value="

Create a new hypernetwork

") @@ -1035,7 +1035,7 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="") with gr.Column(): - create_hypernetwork = gr.Button(value="Create", variant='primary') + create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary') with gr.Group(): gr.HTML(value="

Preprocess images

") @@ -1147,6 +1147,7 @@ def create_ui(wrap_gradio_gpu_call): create_image_every, save_embedding_every, template_file, + preview_image_prompt, ], outputs=[ ti_output, diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 42e1489c..0af5993c 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,7 +10,8 @@ import numpy as np import modules.scripts as scripts import gradio as gr -from modules import images, hypernetwork +from modules import images +from modules.hypernetwork import hypernetwork from modules.processing import process_images, Processed, get_correct_sampler from modules.shared import opts, cmd_opts, state import modules.shared as shared diff --git a/webui.py b/webui.py index 7c200551..ba2156c8 100644 --- a/webui.py +++ b/webui.py @@ -29,6 +29,7 @@ from modules import devices from modules import modelloader from modules.paths import script_path from modules.shared import cmd_opts +import modules.hypernetwork.hypernetwork modelloader.cleanup_models() modules.sd_models.setup_model() @@ -77,22 +78,12 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) -def set_hypernetwork(): - shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None) - - -shared.reload_hypernetworks() -shared.opts.onchange("sd_hypernetwork", set_hypernetwork) -set_hypernetwork() - - modules.scripts.load_scripts(os.path.join(script_path, "scripts")) shared.sd_model = modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) -loaded_hypernetwork = modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork) -shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) +shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) def webui(): @@ -117,7 +108,7 @@ def webui(): prevent_thread_lock=True ) - app.add_middleware(GZipMiddleware,minimum_size=1000) + app.add_middleware(GZipMiddleware, minimum_size=1000) while 1: time.sleep(0.5) -- cgit v1.2.3 From 873efeed49bb5197a42da18272115b326c5d68f3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 15:51:22 +0300 Subject: rename hypernetwork dir to hypernetworks to prevent clash with an old filename that people who use zip instead of git clone will have --- modules/hypernetwork/hypernetwork.py | 283 ---------------------------------- modules/hypernetwork/ui.py | 43 ------ modules/hypernetworks/hypernetwork.py | 283 ++++++++++++++++++++++++++++++++++ modules/hypernetworks/ui.py | 43 ++++++ modules/sd_hijack.py | 2 +- modules/sd_hijack_optimizations.py | 2 +- modules/shared.py | 2 +- modules/ui.py | 2 +- scripts/xy_grid.py | 2 +- webui.py | 2 +- 10 files changed, 332 insertions(+), 332 deletions(-) delete mode 100644 modules/hypernetwork/hypernetwork.py delete mode 100644 modules/hypernetwork/ui.py create mode 100644 modules/hypernetworks/hypernetwork.py create mode 100644 modules/hypernetworks/ui.py (limited to 'scripts/xy_grid.py') diff --git a/modules/hypernetwork/hypernetwork.py b/modules/hypernetwork/hypernetwork.py deleted file mode 100644 index aa701bda..00000000 --- a/modules/hypernetwork/hypernetwork.py +++ /dev/null @@ -1,283 +0,0 @@ -import datetime -import glob -import html -import os -import sys -import traceback -import tqdm - -import torch - -from ldm.util import default -from modules import devices, shared, processing, sd_models -import torch -from torch import einsum -from einops import rearrange, repeat -import modules.textual_inversion.dataset - - -class HypernetworkModule(torch.nn.Module): - def __init__(self, dim, state_dict=None): - super().__init__() - - self.linear1 = torch.nn.Linear(dim, dim * 2) - self.linear2 = torch.nn.Linear(dim * 2, dim) - - if state_dict is not None: - self.load_state_dict(state_dict, strict=True) - else: - - self.linear1.weight.data.normal_(mean=0.0, std=0.01) - self.linear1.bias.data.zero_() - self.linear2.weight.data.normal_(mean=0.0, std=0.01) - self.linear2.bias.data.zero_() - - self.to(devices.device) - - def forward(self, x): - return x + (self.linear2(self.linear1(x))) - - -class Hypernetwork: - filename = None - name = None - - def __init__(self, name=None): - self.filename = None - self.name = name - self.layers = {} - self.step = 0 - self.sd_checkpoint = None - self.sd_checkpoint_name = None - - for size in [320, 640, 768, 1280]: - self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size)) - - def weights(self): - res = [] - - for k, layers in self.layers.items(): - for layer in layers: - layer.train() - res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias] - - return res - - def save(self, filename): - state_dict = {} - - for k, v in self.layers.items(): - state_dict[k] = (v[0].state_dict(), v[1].state_dict()) - - state_dict['step'] = self.step - state_dict['name'] = self.name - state_dict['sd_checkpoint'] = self.sd_checkpoint - state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name - - torch.save(state_dict, filename) - - def load(self, filename): - self.filename = filename - if self.name is None: - self.name = os.path.splitext(os.path.basename(filename))[0] - - state_dict = torch.load(filename, map_location='cpu') - - for size, sd in state_dict.items(): - if type(size) == int: - self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) - - self.name = state_dict.get('name', self.name) - self.step = state_dict.get('step', 0) - self.sd_checkpoint = state_dict.get('sd_checkpoint', None) - self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) - - -def list_hypernetworks(path): - res = {} - for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): - name = os.path.splitext(os.path.basename(filename))[0] - res[name] = filename - return res - - -def load_hypernetwork(filename): - path = shared.hypernetworks.get(filename, None) - if path is not None: - print(f"Loading hypernetwork {filename}") - try: - shared.loaded_hypernetwork = Hypernetwork() - shared.loaded_hypernetwork.load(path) - - except Exception: - print(f"Error loading hypernetwork {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - if shared.loaded_hypernetwork is not None: - print(f"Unloading hypernetwork") - - shared.loaded_hypernetwork = None - - -def apply_hypernetwork(hypernetwork, context, layer=None): - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is None: - return context, context - - if layer is not None: - layer.hyper_k = hypernetwork_layers[0] - layer.hyper_v = hypernetwork_layers[1] - - context_k = hypernetwork_layers[0](context) - context_v = hypernetwork_layers[1](context) - return context_k, context_v - - -def attention_CrossAttention_forward(self, x, context=None, mask=None): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - - context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self) - k = self.to_k(context_k) - v = self.to_v(context_v) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - - if mask is not None: - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) - - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) - - -def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): - assert hypernetwork_name, 'embedding not selected' - - path = shared.hypernetworks.get(hypernetwork_name, None) - shared.loaded_hypernetwork = Hypernetwork() - shared.loaded_hypernetwork.load(path) - - shared.state.textinfo = "Initializing hypernetwork training..." - shared.state.job_count = steps - - filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') - - log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) - - if save_hypernetwork_every > 0: - hypernetwork_dir = os.path.join(log_directory, "hypernetworks") - os.makedirs(hypernetwork_dir, exist_ok=True) - else: - hypernetwork_dir = None - - if create_image_every > 0: - images_dir = os.path.join(log_directory, "images") - os.makedirs(images_dir, exist_ok=True) - else: - images_dir = None - - cond_model = shared.sd_model.cond_stage_model - - shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." - with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) - - hypernetwork = shared.loaded_hypernetwork - weights = hypernetwork.weights() - for weight in weights: - weight.requires_grad = True - - optimizer = torch.optim.AdamW(weights, lr=learn_rate) - - losses = torch.zeros((32,)) - - last_saved_file = "" - last_saved_image = "" - - ititial_step = hypernetwork.step or 0 - if ititial_step > steps: - return hypernetwork, filename - - pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - for i, (x, text) in pbar: - hypernetwork.step = i + ititial_step - - if hypernetwork.step > steps: - break - - if shared.state.interrupted: - break - - with torch.autocast("cuda"): - c = cond_model([text]) - - x = x.to(devices.device) - loss = shared.sd_model(x.unsqueeze(0), c)[0] - del x - - losses[hypernetwork.step % losses.shape[0]] = loss.item() - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - pbar.set_description(f"loss: {losses.mean():.7f}") - - if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: - last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') - hypernetwork.save(last_saved_file) - - if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: - last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') - - preview_text = text if preview_image_prompt == "" else preview_image_prompt - - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - prompt=preview_text, - steps=20, - do_not_save_grid=True, - do_not_save_samples=True, - ) - - processed = processing.process_images(p) - image = processed.images[0] - - shared.state.current_image = image - image.save(last_saved_image) - - last_saved_image += f", prompt: {preview_text}" - - shared.state.job_no = hypernetwork.step - - shared.state.textinfo = f""" -

-Loss: {losses.mean():.7f}
-Step: {hypernetwork.step}
-Last prompt: {html.escape(text)}
-Last saved embedding: {html.escape(last_saved_file)}
-Last saved image: {html.escape(last_saved_image)}
-

-""" - - checkpoint = sd_models.select_checkpoint() - - hypernetwork.sd_checkpoint = checkpoint.hash - hypernetwork.sd_checkpoint_name = checkpoint.model_name - hypernetwork.save(filename) - - return hypernetwork, filename - - diff --git a/modules/hypernetwork/ui.py b/modules/hypernetwork/ui.py deleted file mode 100644 index f6d1d0a3..00000000 --- a/modules/hypernetwork/ui.py +++ /dev/null @@ -1,43 +0,0 @@ -import html -import os - -import gradio as gr - -import modules.textual_inversion.textual_inversion -import modules.textual_inversion.preprocess -from modules import sd_hijack, shared -from modules.hypernetwork import hypernetwork - - -def create_hypernetwork(name): - fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") - assert not os.path.exists(fn), f"file {fn} already exists" - - hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) - hypernet.save(fn) - - shared.reload_hypernetworks() - - return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" - - -def train_hypernetwork(*args): - - initial_hypernetwork = shared.loaded_hypernetwork - - try: - sd_hijack.undo_optimizations() - - hypernetwork, filename = modules.hypernetwork.hypernetwork.train_hypernetwork(*args) - - res = f""" -Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. -Hypernetwork saved to {html.escape(filename)} -""" - return res, "" - except Exception: - raise - finally: - shared.loaded_hypernetwork = initial_hypernetwork - sd_hijack.apply_optimizations() - diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py new file mode 100644 index 00000000..aa701bda --- /dev/null +++ b/modules/hypernetworks/hypernetwork.py @@ -0,0 +1,283 @@ +import datetime +import glob +import html +import os +import sys +import traceback +import tqdm + +import torch + +from ldm.util import default +from modules import devices, shared, processing, sd_models +import torch +from torch import einsum +from einops import rearrange, repeat +import modules.textual_inversion.dataset + + +class HypernetworkModule(torch.nn.Module): + def __init__(self, dim, state_dict=None): + super().__init__() + + self.linear1 = torch.nn.Linear(dim, dim * 2) + self.linear2 = torch.nn.Linear(dim * 2, dim) + + if state_dict is not None: + self.load_state_dict(state_dict, strict=True) + else: + + self.linear1.weight.data.normal_(mean=0.0, std=0.01) + self.linear1.bias.data.zero_() + self.linear2.weight.data.normal_(mean=0.0, std=0.01) + self.linear2.bias.data.zero_() + + self.to(devices.device) + + def forward(self, x): + return x + (self.linear2(self.linear1(x))) + + +class Hypernetwork: + filename = None + name = None + + def __init__(self, name=None): + self.filename = None + self.name = name + self.layers = {} + self.step = 0 + self.sd_checkpoint = None + self.sd_checkpoint_name = None + + for size in [320, 640, 768, 1280]: + self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size)) + + def weights(self): + res = [] + + for k, layers in self.layers.items(): + for layer in layers: + layer.train() + res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias] + + return res + + def save(self, filename): + state_dict = {} + + for k, v in self.layers.items(): + state_dict[k] = (v[0].state_dict(), v[1].state_dict()) + + state_dict['step'] = self.step + state_dict['name'] = self.name + state_dict['sd_checkpoint'] = self.sd_checkpoint + state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name + + torch.save(state_dict, filename) + + def load(self, filename): + self.filename = filename + if self.name is None: + self.name = os.path.splitext(os.path.basename(filename))[0] + + state_dict = torch.load(filename, map_location='cpu') + + for size, sd in state_dict.items(): + if type(size) == int: + self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) + + self.name = state_dict.get('name', self.name) + self.step = state_dict.get('step', 0) + self.sd_checkpoint = state_dict.get('sd_checkpoint', None) + self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) + + +def list_hypernetworks(path): + res = {} + for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): + name = os.path.splitext(os.path.basename(filename))[0] + res[name] = filename + return res + + +def load_hypernetwork(filename): + path = shared.hypernetworks.get(filename, None) + if path is not None: + print(f"Loading hypernetwork {filename}") + try: + shared.loaded_hypernetwork = Hypernetwork() + shared.loaded_hypernetwork.load(path) + + except Exception: + print(f"Error loading hypernetwork {path}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + else: + if shared.loaded_hypernetwork is not None: + print(f"Unloading hypernetwork") + + shared.loaded_hypernetwork = None + + +def apply_hypernetwork(hypernetwork, context, layer=None): + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is None: + return context, context + + if layer is not None: + layer.hyper_k = hypernetwork_layers[0] + layer.hyper_v = hypernetwork_layers[1] + + context_k = hypernetwork_layers[0](context) + context_v = hypernetwork_layers[1](context) + return context_k, context_v + + +def attention_CrossAttention_forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self) + k = self.to_k(context_k) + v = self.to_v(context_v) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if mask is not None: + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): + assert hypernetwork_name, 'embedding not selected' + + path = shared.hypernetworks.get(hypernetwork_name, None) + shared.loaded_hypernetwork = Hypernetwork() + shared.loaded_hypernetwork.load(path) + + shared.state.textinfo = "Initializing hypernetwork training..." + shared.state.job_count = steps + + filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') + + log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) + + if save_hypernetwork_every > 0: + hypernetwork_dir = os.path.join(log_directory, "hypernetworks") + os.makedirs(hypernetwork_dir, exist_ok=True) + else: + hypernetwork_dir = None + + if create_image_every > 0: + images_dir = os.path.join(log_directory, "images") + os.makedirs(images_dir, exist_ok=True) + else: + images_dir = None + + cond_model = shared.sd_model.cond_stage_model + + shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." + with torch.autocast("cuda"): + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) + + hypernetwork = shared.loaded_hypernetwork + weights = hypernetwork.weights() + for weight in weights: + weight.requires_grad = True + + optimizer = torch.optim.AdamW(weights, lr=learn_rate) + + losses = torch.zeros((32,)) + + last_saved_file = "" + last_saved_image = "" + + ititial_step = hypernetwork.step or 0 + if ititial_step > steps: + return hypernetwork, filename + + pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) + for i, (x, text) in pbar: + hypernetwork.step = i + ititial_step + + if hypernetwork.step > steps: + break + + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + c = cond_model([text]) + + x = x.to(devices.device) + loss = shared.sd_model(x.unsqueeze(0), c)[0] + del x + + losses[hypernetwork.step % losses.shape[0]] = loss.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + pbar.set_description(f"loss: {losses.mean():.7f}") + + if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') + hypernetwork.save(last_saved_file) + + if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: + last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') + + preview_text = text if preview_image_prompt == "" else preview_image_prompt + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + prompt=preview_text, + steps=20, + do_not_save_grid=True, + do_not_save_samples=True, + ) + + processed = processing.process_images(p) + image = processed.images[0] + + shared.state.current_image = image + image.save(last_saved_image) + + last_saved_image += f", prompt: {preview_text}" + + shared.state.job_no = hypernetwork.step + + shared.state.textinfo = f""" +

+Loss: {losses.mean():.7f}
+Step: {hypernetwork.step}
+Last prompt: {html.escape(text)}
+Last saved embedding: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+

+""" + + checkpoint = sd_models.select_checkpoint() + + hypernetwork.sd_checkpoint = checkpoint.hash + hypernetwork.sd_checkpoint_name = checkpoint.model_name + hypernetwork.save(filename) + + return hypernetwork, filename + + diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py new file mode 100644 index 00000000..811bc31e --- /dev/null +++ b/modules/hypernetworks/ui.py @@ -0,0 +1,43 @@ +import html +import os + +import gradio as gr + +import modules.textual_inversion.textual_inversion +import modules.textual_inversion.preprocess +from modules import sd_hijack, shared +from modules.hypernetworks import hypernetwork + + +def create_hypernetwork(name): + fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") + assert not os.path.exists(fn), f"file {fn} already exists" + + hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) + hypernet.save(fn) + + shared.reload_hypernetworks() + + return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" + + +def train_hypernetwork(*args): + + initial_hypernetwork = shared.loaded_hypernetwork + + try: + sd_hijack.undo_optimizations() + + hypernetwork, filename = modules.hypernetwork.hypernetwork.train_hypernetwork(*args) + + res = f""" +Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. +Hypernetwork saved to {html.escape(filename)} +""" + return res, "" + except Exception: + raise + finally: + shared.loaded_hypernetwork = initial_hypernetwork + sd_hijack.apply_optimizations() + diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f873049a..f07ec041 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -37,7 +37,7 @@ def apply_optimizations(): def undo_optimizations(): - from modules.hypernetwork import hypernetwork + from modules.hypernetworks import hypernetwork ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 27e571fc..3349b9c3 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -9,7 +9,7 @@ from ldm.util import default from einops import rearrange from modules import shared -from modules.hypernetwork import hypernetwork +from modules.hypernetworks import hypernetwork if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: diff --git a/modules/shared.py b/modules/shared.py index 375e3afb..1dc2ccf2 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.sd_models import modules.styles import modules.devices as devices from modules import sd_samplers -from modules.hypernetwork import hypernetwork +from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') diff --git a/modules/ui.py b/modules/ui.py index f57f32db..42e5d866 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -39,7 +39,7 @@ import modules.generation_parameters_copypaste from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui -import modules.hypernetwork.ui +import modules.hypernetworks.ui # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 16918c99..cddb192a 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -11,7 +11,7 @@ import modules.scripts as scripts import gradio as gr from modules import images -from modules.hypernetwork import hypernetwork +from modules.hypernetworks import hypernetwork from modules.processing import process_images, Processed, get_correct_sampler from modules.shared import opts, cmd_opts, state import modules.shared as shared diff --git a/webui.py b/webui.py index ba2156c8..faa38a0d 100644 --- a/webui.py +++ b/webui.py @@ -29,7 +29,7 @@ from modules import devices from modules import modelloader from modules.paths import script_path from modules.shared import cmd_opts -import modules.hypernetwork.hypernetwork +import modules.hypernetworks.hypernetwork modelloader.cleanup_models() modules.sd_models.setup_model() -- cgit v1.2.3 From 5ba23cb41f28f5856a7f64cb0d95e1e94dce90af Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 17:28:17 +0300 Subject: change default for XY plot's Y to Nothing. --- scripts/xy_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index cddb192a..ef431105 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -197,7 +197,7 @@ class Script(scripts.Script): x_values = gr.Textbox(label="X values", visible=False, lines=1) with gr.Row(): - y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[4].label, visible=False, type="index", elem_id="y_type") + y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, visible=False, type="index", elem_id="y_type") y_values = gr.Textbox(label="Y values", visible=False, lines=1) draw_legend = gr.Checkbox(label='Draw legend', value=True) -- cgit v1.2.3 From 2d006ce16cd95d587533656c3ac4991495e96f23 Mon Sep 17 00:00:00 2001 From: Milly Date: Mon, 10 Oct 2022 00:56:36 +0900 Subject: xy_grid: Find hypernetwork by closest name --- modules/hypernetworks/hypernetwork.py | 11 +++++++++++ scripts/xy_grid.py | 6 +++++- 2 files changed, 16 insertions(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 470659df..8f2192e2 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -120,6 +120,17 @@ def load_hypernetwork(filename): shared.loaded_hypernetwork = None +def find_closest_hypernetwork_name(search: str): + if not search: + return None + search = search.lower() + applicable = [name for name in shared.hypernetworks if search in name.lower()] + if not applicable: + return None + applicable = sorted(applicable, key=lambda name: len(name)) + return applicable[0] + + def apply_hypernetwork(hypernetwork, context, layer=None): hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index ef431105..6f4217ec 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -84,7 +84,11 @@ def apply_checkpoint(p, x, xs): def apply_hypernetwork(p, x, xs): - hypernetwork.load_hypernetwork(x) + if x.lower() in ["", "none"]: + name = None + else: + name = hypernetwork.find_closest_hypernetwork_name(x) + hypernetwork.load_hypernetwork(name) def apply_clip_skip(p, x, xs): -- cgit v1.2.3 From 7dba1c07cb337114507d9c256f9b843162c187d6 Mon Sep 17 00:00:00 2001 From: Milly Date: Mon, 10 Oct 2022 01:37:09 +0900 Subject: xy_grid: Confirm that hypernetwork options are valid before starting --- scripts/xy_grid.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 6f4217ec..b2239d0a 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -88,9 +88,19 @@ def apply_hypernetwork(p, x, xs): name = None else: name = hypernetwork.find_closest_hypernetwork_name(x) + if not name: + raise RuntimeError(f"Unknown hypernetwork: {x}") hypernetwork.load_hypernetwork(name) +def confirm_hypernetworks(xs): + for x in xs: + if x.lower() in ["", "none"]: + continue + if not hypernetwork.find_closest_hypernetwork_name(x): + raise RuntimeError(f"Unknown hypernetwork: {x}") + + def apply_clip_skip(p, x, xs): opts.data["CLIP_stop_at_last_layers"] = x @@ -284,6 +294,8 @@ class Script(scripts.Script): for ckpt_val in valslist: if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None: raise RuntimeError(f"Checkpoint for {ckpt_val} not found") + elif opt.label == "Hypernetwork": + confirm_hypernetworks(valslist) return valslist -- cgit v1.2.3 From 2fffd4bddce12b2c98a5bae5a2cc6d64450d65a0 Mon Sep 17 00:00:00 2001 From: Milly Date: Mon, 10 Oct 2022 02:20:35 +0900 Subject: xy_grid: Refactor confirm functions --- scripts/xy_grid.py | 73 +++++++++++++++++++++++++++++------------------------- 1 file changed, 39 insertions(+), 34 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index b2239d0a..3bb080bf 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -77,12 +77,26 @@ def apply_sampler(p, x, xs): p.sampler_index = sampler_index +def confirm_samplers(p, xs): + samplers_dict = build_samplers_dict(p) + for x in xs: + if x.lower() not in samplers_dict.keys(): + raise RuntimeError(f"Unknown sampler: {x}") + + def apply_checkpoint(p, x, xs): info = modules.sd_models.get_closet_checkpoint_match(x) - assert info is not None, f'Checkpoint for {x} not found' + if info is None: + raise RuntimeError(f"Unknown checkpoint: {x}") modules.sd_models.reload_model_weights(shared.sd_model, info) +def confirm_checkpoints(p, xs): + for x in xs: + if modules.sd_models.get_closet_checkpoint_match(x) is None: + raise RuntimeError(f"Unknown checkpoint: {x}") + + def apply_hypernetwork(p, x, xs): if x.lower() in ["", "none"]: name = None @@ -93,7 +107,7 @@ def apply_hypernetwork(p, x, xs): hypernetwork.load_hypernetwork(name) -def confirm_hypernetworks(xs): +def confirm_hypernetworks(p, xs): for x in xs: if x.lower() in ["", "none"]: continue @@ -135,29 +149,29 @@ def str_permutations(x): return x -AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"]) -AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"]) +AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"]) +AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"]) axis_options = [ - AxisOption("Nothing", str, do_nothing, format_nothing), - AxisOption("Seed", int, apply_field("seed"), format_value_add_label), - AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label), - AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label), - AxisOption("Steps", int, apply_field("steps"), format_value_add_label), - AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label), - AxisOption("Prompt S/R", str, apply_prompt, format_value), - AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list), - AxisOption("Sampler", str, apply_sampler, format_value), - AxisOption("Checkpoint name", str, apply_checkpoint, format_value), - AxisOption("Hypernetwork", str, apply_hypernetwork, format_value), - AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label), - AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label), - AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), - AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label), - AxisOption("Eta", float, apply_field("eta"), format_value_add_label), - AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label), - AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones + AxisOption("Nothing", str, do_nothing, format_nothing, None), + AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None), + AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None), + AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None), + AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None), + AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None), + AxisOption("Prompt S/R", str, apply_prompt, format_value, None), + AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None), + AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers), + AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints), + AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks), + AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None), + AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None), + AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None), + AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None), + AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), + AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), + AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), # as it is now all AxisOptionImg2Img items must go after AxisOption ones ] @@ -283,19 +297,10 @@ class Script(scripts.Script): valslist = list(permutations(valslist)) valslist = [opt.type(x) for x in valslist] - + # Confirm options are valid before starting - if opt.label == "Sampler": - samplers_dict = build_samplers_dict(p) - for sampler_val in valslist: - if sampler_val.lower() not in samplers_dict.keys(): - raise RuntimeError(f"Unknown sampler: {sampler_val}") - elif opt.label == "Checkpoint name": - for ckpt_val in valslist: - if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None: - raise RuntimeError(f"Checkpoint for {ckpt_val} not found") - elif opt.label == "Hypernetwork": - confirm_hypernetworks(valslist) + if opt.confirm: + opt.confirm(p, valslist) return valslist -- cgit v1.2.3 From 94c01aa35656130b56f401830ad443ce3d97c364 Mon Sep 17 00:00:00 2001 From: Greg Fuller Date: Tue, 11 Oct 2022 17:05:20 -0700 Subject: draw_xy_grid provides the option to also return lone images --- scripts/xy_grid.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 3bb080bf..14edacc1 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -175,8 +175,9 @@ axis_options = [ ] -def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): +def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images): res = [] + successful_images = [] ver_texts = [[images.GridAnnotation(y)] for y in y_labels] hor_texts = [[images.GridAnnotation(x)] for x in x_labels] @@ -194,7 +195,9 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): first_processed = processed try: - res.append(processed.images[0]) + processed_image = processed.images[0] + res.append(processed_image) + successful_images.append(processed_image) except: res.append(Image.new(res[0].mode, res[0].size)) @@ -203,6 +206,8 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) first_processed.images = [grid] + if include_lone_images: + first_processed.images += successful_images return first_processed @@ -229,11 +234,12 @@ class Script(scripts.Script): y_values = gr.Textbox(label="Y values", visible=False, lines=1) draw_legend = gr.Checkbox(label='Draw legend', value=True) + include_lone_images = gr.Checkbox(label='Include Separate Images', value=True) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False) - return [x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds] + return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] - def run(self, p, x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds): + def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds): if not no_fixed_seeds: modules.processing.fix_seed(p) @@ -344,7 +350,8 @@ class Script(scripts.Script): x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], cell=cell, - draw_legend=draw_legend + draw_legend=draw_legend, + include_lone_images=include_lone_images ) if opts.grid_save: -- cgit v1.2.3 From 8711c2fe0135d5c160a57db41cb79ed1942ce7fa Mon Sep 17 00:00:00 2001 From: Greg Fuller Date: Wed, 12 Oct 2022 16:12:12 -0700 Subject: Fix metadata contents --- modules/ui.py | 5 +---- scripts/xy_grid.py | 52 ++++++++++++++++++++++++++++++++++------------------ 2 files changed, 35 insertions(+), 22 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/ui.py b/modules/ui.py index 4fa405a9..e07ee0e1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -148,10 +148,7 @@ def save_files(js_data, images, do_make_zip, index): is_grid = image_index < p.index_of_first_image i = 0 if is_grid else (image_index - p.index_of_first_image) - seed = p.all_seeds[i] if len(p.all_seeds) > 1 else p.seed - prompt = p.all_prompts[i] if len(p.all_prompts) > 1 else p.prompt - info = p.infotexts[image_index] if len(p.infotexts) > 1 else p.infotexts[0] - fullfn, txt_fullfn = save_image(image, path, "", seed=seed, prompt=prompt, extension=extension, info=info, grid=is_grid, p=p, save_to_dirs=save_to_dirs) + fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) filename = os.path.relpath(fullfn, path) filenames.append(filename) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 14edacc1..02931ae6 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -176,13 +176,16 @@ axis_options = [ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images): - res = [] - successful_images = [] - ver_texts = [[images.GridAnnotation(y)] for y in y_labels] hor_texts = [[images.GridAnnotation(x)] for x in x_labels] - first_processed = None + # Temporary list of all the images that are generated to be populated into the grid. + # Will be filled with empty images for any individual step that fails to process properly + image_cache = [] + + processed_result = None + cell_mode = "P" + cell_size = (1,1) state.job_count = len(xs) * len(ys) * p.n_iter @@ -190,26 +193,39 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_ for ix, x in enumerate(xs): state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" - processed = cell(x, y) - if first_processed is None: - first_processed = processed - + processed:Processed = cell(x, y) try: - processed_image = processed.images[0] - res.append(processed_image) - successful_images.append(processed_image) + # this dereference will throw an exception if the image was not processed + # (this happens in cases such as if the user stops the process from the UI) + processed_image = processed.images[0] + + if processed_result is None: + # Use our first valid processed result as a template container to hold our full results + processed_result = copy(processed) + cell_mode = processed_image.mode + cell_size = processed_image.size + processed_result.images = [Image.new(cell_mode, cell_size)] + + image_cache.append(processed_image) + if include_lone_images: + processed_result.images.append(processed_image) + processed_result.all_prompts.append(processed.prompt) + processed_result.all_seeds.append(processed.seed) + processed_result.infotexts.append(processed.infotexts[0]) except: - res.append(Image.new(res[0].mode, res[0].size)) + image_cache.append(Image.new(cell_mode, cell_size)) + + if not processed_result: + print("Unexpected error: draw_xy_grid failed to return even a single processed image") + return Processed() - grid = images.image_grid(res, rows=len(ys)) + grid = images.image_grid(image_cache, rows=len(ys)) if draw_legend: - grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) + grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts) - first_processed.images = [grid] - if include_lone_images: - first_processed.images += successful_images + processed_result.images[0] = grid - return first_processed + return processed_result re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") -- cgit v1.2.3 From 354ef0da3b1f0fa5c113d04b6c79e3908c848d23 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 13 Oct 2022 20:12:37 +0300 Subject: add hypernetwork multipliers --- modules/hypernetworks/hypernetwork.py | 8 +++++++- modules/shared.py | 5 ++++- modules/ui.py | 5 ++++- scripts/xy_grid.py | 9 ++++++++- style.css | 3 +++ webui.py | 2 +- 6 files changed, 27 insertions(+), 5 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b6c06d49..f1248bb7 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -18,6 +18,8 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler class HypernetworkModule(torch.nn.Module): + multiplier = 1.0 + def __init__(self, dim, state_dict=None): super().__init__() @@ -36,7 +38,11 @@ class HypernetworkModule(torch.nn.Module): self.to(devices.device) def forward(self, x): - return x + (self.linear2(self.linear1(x))) + return x + (self.linear2(self.linear1(x))) * self.multiplier + + +def apply_strength(value=None): + HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength class Hypernetwork: diff --git a/modules/shared.py b/modules/shared.py index d8e3a286..5901e605 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -238,7 +238,8 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), - "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), + "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), + "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), @@ -348,6 +349,8 @@ class Options: item = self.data_labels.get(key) item.onchange = func + func() + def dumpjson(self): d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()} return json.dumps(d) diff --git a/modules/ui.py b/modules/ui.py index 0a58f6be..673014f2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1244,7 +1244,10 @@ def create_ui(wrap_gradio_gpu_call): def refresh(): info.refresh() refreshed_args = info.component_args() if callable(info.component_args) else info.component_args - res.choices = refreshed_args["choices"] + + for k, v in refreshed_args.items(): + setattr(res, k, v) + return gr.update(**(refreshed_args or {})) refresh_button.click( diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 02931ae6..efb63af5 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -107,6 +107,10 @@ def apply_hypernetwork(p, x, xs): hypernetwork.load_hypernetwork(name) +def apply_hypernetwork_strength(p, x, xs): + hypernetwork.apply_strength(x) + + def confirm_hypernetworks(p, xs): for x in xs: if x.lower() in ["", "none"]: @@ -165,6 +169,7 @@ axis_options = [ AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers), AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints), AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks), + AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None), AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None), AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None), AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None), @@ -250,7 +255,7 @@ class Script(scripts.Script): y_values = gr.Textbox(label="Y values", visible=False, lines=1) draw_legend = gr.Checkbox(label='Draw legend', value=True) - include_lone_images = gr.Checkbox(label='Include Separate Images', value=True) + include_lone_images = gr.Checkbox(label='Include Separate Images', value=False) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False) return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] @@ -377,6 +382,8 @@ class Script(scripts.Script): modules.sd_models.reload_model_weights(shared.sd_model) hypernetwork.load_hypernetwork(opts.sd_hypernetwork) + hypernetwork.apply_strength() + opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers diff --git a/style.css b/style.css index ad2a52cc..aa3d379c 100644 --- a/style.css +++ b/style.css @@ -522,6 +522,9 @@ canvas[key="mask"] { z-index: 200; width: 8em; } +#quicksettings .gr-box > div > div > input.gr-text-input { + top: -1.12em; +} .row.gr-compact{ overflow: visible; diff --git a/webui.py b/webui.py index 33ba7905..fe0ce321 100644 --- a/webui.py +++ b/webui.py @@ -72,7 +72,6 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) - def initialize(): modelloader.cleanup_models() modules.sd_models.setup_model() @@ -86,6 +85,7 @@ def initialize(): shared.sd_model = modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) + shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) def webui(): -- cgit v1.2.3 From 02382f7ce462a360e8aea9ee3178da48b564f70a Mon Sep 17 00:00:00 2001 From: RnDMonkey Date: Wed, 12 Oct 2022 16:35:36 -0700 Subject: regression in xy_grid Var. seed fixing --- scripts/xy_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index efb63af5..5700b007 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -338,7 +338,7 @@ class Script(scripts.Script): ys = process_axis(y_opt, y_values) def fix_axis_seeds(axis_opt, axis_list): - if axis_opt.label == 'Seed': + if axis_opt.label in ['Seed','Var. seed']: return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] else: return axis_list -- cgit v1.2.3 From 4cc37e4cdf2ce5f5b753786b55ae1d4abd530c01 Mon Sep 17 00:00:00 2001 From: Naeaeaeaeae Date: Thu, 13 Oct 2022 18:49:58 +0200 Subject: [xy_grid.py] add option denoising_strength --- scripts/xy_grid.py | 1 + 1 file changed, 1 insertion(+) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 5700b007..fda2b71d 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -176,6 +176,7 @@ axis_options = [ AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None), AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), + AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), # as it is now all AxisOptionImg2Img items must go after AxisOption ones ] -- cgit v1.2.3 From 989a552de3d1fcd1f178fe873713b884e192dd61 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 14 Oct 2022 22:04:08 +0300 Subject: remove the other Denoising --- scripts/xy_grid.py | 1 - 1 file changed, 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index fda2b71d..8c7da6bb 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -177,7 +177,6 @@ axis_options = [ AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), - AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), # as it is now all AxisOptionImg2Img items must go after AxisOption ones ] -- cgit v1.2.3 From a8f7722e4e7460122b44589c3718eee0c597009d Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 14 Oct 2022 14:26:38 -0700 Subject: Fix XY-plot steps if highres fix is enabled --- scripts/xy_grid.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 8c7da6bb..88ad3bf7 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -12,7 +12,7 @@ import gradio as gr from modules import images from modules.hypernetworks import hypernetwork -from modules.processing import process_images, Processed, get_correct_sampler +from modules.processing import process_images, Processed, get_correct_sampler, StableDiffusionProcessingTxt2Img from modules.shared import opts, cmd_opts, state import modules.shared as shared import modules.sd_samplers @@ -354,6 +354,9 @@ class Script(scripts.Script): else: total_steps = p.steps * len(xs) * len(ys) + if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr: + total_steps *= 2 + print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})") shared.total_tqdm.updateTotal(total_steps * p.n_iter) -- cgit v1.2.3