From 1c5604791da7e57f40880698666b6617a1754c65 Mon Sep 17 00:00:00 2001 From: DoTheSneedful Date: Mon, 3 Oct 2022 22:20:09 -0400 Subject: Add a prompt order option to XY plot script --- scripts/xy_grid.py | 40 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 146663b0..044c30e6 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -1,5 +1,6 @@ from collections import namedtuple from copy import copy +from itertools import permutations import random from PIL import Image @@ -28,6 +29,27 @@ def apply_prompt(p, x, xs): p.prompt = p.prompt.replace(xs[0], x) p.negative_prompt = p.negative_prompt.replace(xs[0], x) +def apply_order(p, x, xs): + token_order = [] + + # Initally grab the tokens from the prompt so they can be later be replaced in order of earliest seen in the prompt + for token in x: + token_order.append((p.prompt.find(token), token)) + + token_order.sort(key=lambda t: t[0]) + + search_from_pos = 0 + for idx, token in enumerate(x): + original_pos, old_token = token_order[idx] + + # Get position of the token again as it will likely change as tokens are being replaced + pos = p.prompt.find(old_token) + if original_pos >= 0: + # Avoid trying to replace what was just replaced by searching later in the prompt string + p.prompt = p.prompt[0:search_from_pos] + p.prompt[search_from_pos:].replace(old_token, token, 1) + + search_from_pos = pos + len(token) + samplers_dict = {} for i, sampler in enumerate(modules.sd_samplers.samplers): @@ -60,7 +82,8 @@ def format_value_add_label(p, opt, x): def format_value(p, opt, x): if type(x) == float: x = round(x, 8) - + if type(x) == type(list()): + x = str(x) return x def do_nothing(p, x, xs): @@ -89,6 +112,7 @@ axis_options = [ AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label), AxisOption("Eta", float, apply_field("eta"), format_value_add_label), + AxisOption("Prompt order", type(list()), apply_order, format_value), AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones ] @@ -159,7 +183,11 @@ class Script(scripts.Script): if opt.label == 'Nothing': return [0] - valslist = [x.strip() for x in vals.split(",")] + if opt.type == type(list()): + valslist = [x for x in vals] + else: + valslist = [x.strip() for x in vals.split(",")] + if opt.type == int: valslist_ext = [] @@ -212,9 +240,17 @@ class Script(scripts.Script): return valslist x_opt = axis_options[x_type] + + if x_opt.label == "Prompt order": + x_values = list(permutations([x.strip() for x in x_values.split(",")])) + xs = process_axis(x_opt, x_values) y_opt = axis_options[y_type] + + if y_opt.label == "Prompt order": + y_values = list(permutations([y.strip() for y in y_values.split(",")])) + ys = process_axis(y_opt, y_values) def fix_axis_seeds(axis_opt, axis_list): -- cgit v1.2.3 From 1a6d40db35656083d5bf9d3a3430b45fda4e85eb Mon Sep 17 00:00:00 2001 From: DoTheSneedful Date: Tue, 4 Oct 2022 00:18:15 -0400 Subject: Fix token ordering in prompt order XY plot --- scripts/xy_grid.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 044c30e6..5bcd3921 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -32,24 +32,21 @@ def apply_prompt(p, x, xs): def apply_order(p, x, xs): token_order = [] - # Initally grab the tokens from the prompt so they can be later be replaced in order of earliest seen in the prompt + # Initally grab the tokens from the prompt so they can be be replaced in order of earliest seen for token in x: token_order.append((p.prompt.find(token), token)) token_order.sort(key=lambda t: t[0]) search_from_pos = 0 - for idx, token in enumerate(x): - original_pos, old_token = token_order[idx] - + for idx, (original_pos, old_token) in enumerate(token_order): # Get position of the token again as it will likely change as tokens are being replaced - pos = p.prompt.find(old_token) + pos = search_from_pos + p.prompt[search_from_pos:].find(old_token) if original_pos >= 0: # Avoid trying to replace what was just replaced by searching later in the prompt string - p.prompt = p.prompt[0:search_from_pos] + p.prompt[search_from_pos:].replace(old_token, token, 1) - - search_from_pos = pos + len(token) + p.prompt = p.prompt[0:search_from_pos] + p.prompt[search_from_pos:].replace(old_token, x[idx], 1) + search_from_pos = pos + len(x[idx]) samplers_dict = {} for i, sampler in enumerate(modules.sd_samplers.samplers): -- cgit v1.2.3 From 56371153b545e3a43c3a5f206264019af361f3af Mon Sep 17 00:00:00 2001 From: DoTheSneedful Date: Tue, 4 Oct 2022 01:07:36 -0400 Subject: XY plot prompt order simplify logic --- scripts/xy_grid.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 5bcd3921..7def47f5 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -38,15 +38,21 @@ def apply_order(p, x, xs): token_order.sort(key=lambda t: t[0]) - search_from_pos = 0 - for idx, (original_pos, old_token) in enumerate(token_order): - # Get position of the token again as it will likely change as tokens are being replaced - pos = search_from_pos + p.prompt[search_from_pos:].find(old_token) - if original_pos >= 0: - # Avoid trying to replace what was just replaced by searching later in the prompt string - p.prompt = p.prompt[0:search_from_pos] + p.prompt[search_from_pos:].replace(old_token, x[idx], 1) - - search_from_pos = pos + len(x[idx]) + prompt_parts = [] + + # Split the prompt up, taking out the tokens + for _, token in token_order: + n = p.prompt.find(token) + prompt_parts.append(p.prompt[0:n]) + p.prompt = p.prompt[n + len(token):] + + # Rebuild the prompt with the tokens in the order we want + prompt_tmp = "" + for idx, part in enumerate(prompt_parts): + prompt_tmp += part + prompt_tmp += x[idx] + p.prompt = prompt_tmp + p.prompt + samplers_dict = {} for i, sampler in enumerate(modules.sd_samplers.samplers): -- cgit v1.2.3 From 556c36b9607e3f4eacdddc85f8e7a78b29476ea7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 4 Oct 2022 09:18:00 +0300 Subject: add hint, refactor code for #1607 --- javascript/hints.js | 1 + scripts/xy_grid.py | 35 ++++++++++++++++++----------------- 2 files changed, 19 insertions(+), 17 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/javascript/hints.js b/javascript/hints.js index e72e9338..8adcd983 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -47,6 +47,7 @@ titles = { "Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work", "Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others", + "Prompt order": "Separate a list of words with commas, and the script will make a variation of prompt with those words for their every possible order", "Tiling": "Produce an image that can be tiled.", "Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.", diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 7def47f5..1237e754 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -29,10 +29,11 @@ def apply_prompt(p, x, xs): p.prompt = p.prompt.replace(xs[0], x) p.negative_prompt = p.negative_prompt.replace(xs[0], x) + def apply_order(p, x, xs): token_order = [] - # Initally grab the tokens from the prompt so they can be be replaced in order of earliest seen + # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen for token in x: token_order.append((p.prompt.find(token), token)) @@ -85,17 +86,26 @@ def format_value_add_label(p, opt, x): def format_value(p, opt, x): if type(x) == float: x = round(x, 8) - if type(x) == type(list()): - x = str(x) return x + +def format_value_join_list(p, opt, x): + return ", ".join(x) + + def do_nothing(p, x, xs): pass + def format_nothing(p, opt, x): return "" +def str_permutations(x): + """dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" + return x + + AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"]) AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"]) @@ -108,6 +118,7 @@ axis_options = [ AxisOption("Steps", int, apply_field("steps"), format_value_add_label), AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label), AxisOption("Prompt S/R", str, apply_prompt, format_value), + AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list), AxisOption("Sampler", str, apply_sampler, format_value), AxisOption("Checkpoint name", str, apply_checkpoint, format_value), AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label), @@ -115,7 +126,6 @@ axis_options = [ AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label), AxisOption("Eta", float, apply_field("eta"), format_value_add_label), - AxisOption("Prompt order", type(list()), apply_order, format_value), AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones ] @@ -158,6 +168,7 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") + class Script(scripts.Script): def title(self): return "X/Y plot" @@ -186,11 +197,7 @@ class Script(scripts.Script): if opt.label == 'Nothing': return [0] - if opt.type == type(list()): - valslist = [x for x in vals] - else: - valslist = [x.strip() for x in vals.split(",")] - + valslist = [x.strip() for x in vals.split(",")] if opt.type == int: valslist_ext = [] @@ -237,23 +244,17 @@ class Script(scripts.Script): valslist_ext.append(val) valslist = valslist_ext + elif opt.type == str_permutations: + valslist = list(permutations(valslist)) valslist = [opt.type(x) for x in valslist] return valslist x_opt = axis_options[x_type] - - if x_opt.label == "Prompt order": - x_values = list(permutations([x.strip() for x in x_values.split(",")])) - xs = process_axis(x_opt, x_values) y_opt = axis_options[y_type] - - if y_opt.label == "Prompt order": - y_values = list(permutations([y.strip() for y in y_values.split(",")])) - ys = process_axis(y_opt, y_values) def fix_axis_seeds(axis_opt, axis_list): -- cgit v1.2.3 From 5d0e6ab8567bda2ee8f5ed31f332ca07c1b84b98 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 6 Oct 2022 04:04:50 +0100 Subject: Allow escaping of commas in xy_grid --- scripts/xy_grid.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 1237e754..210829a7 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -168,6 +168,7 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") +re_non_escaped_comma = re.compile(r"(? Date: Thu, 6 Oct 2022 11:55:21 +0100 Subject: use csv.reader --- scripts/xy_grid.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 210829a7..1a625898 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -1,8 +1,9 @@ from collections import namedtuple from copy import copy -from itertools import permutations +from itertools import permutations, chain import random - +import csv +from io import StringIO from PIL import Image import numpy as np @@ -168,8 +169,6 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") -re_non_escaped_comma = re.compile(r"(? Date: Thu, 6 Oct 2022 12:32:17 +0100 Subject: strip() split comma delimited lines --- scripts/xy_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 1a625898..ec27e58b 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -197,7 +197,7 @@ class Script(scripts.Script): if opt.label == 'Nothing': return [0] - valslist = list(chain.from_iterable(csv.reader(StringIO(s)))) + valslist = list(map(str.strip,chain.from_iterable(csv.reader(StringIO(s))))) if opt.type == int: valslist_ext = [] -- cgit v1.2.3 From 82eb8ea452b1e63535c58d15ec6db2ad2342faa8 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 6 Oct 2022 15:22:51 +0100 Subject: Update xy_grid.py split vals not 's' from tests --- scripts/xy_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index ec27e58b..210c7b6e 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -197,7 +197,7 @@ class Script(scripts.Script): if opt.label == 'Nothing': return [0] - valslist = list(map(str.strip,chain.from_iterable(csv.reader(StringIO(s))))) + valslist = list(map(str.strip,chain.from_iterable(csv.reader(StringIO(vals))))) if opt.type == int: valslist_ext = [] -- cgit v1.2.3 From 1069ec49a35d04c1e85c92534e92a2d6aa59cb75 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 6 Oct 2022 20:16:21 +0300 Subject: revert back to using list comprehension rather than list and map --- scripts/xy_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 210c7b6e..6344e612 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -197,7 +197,7 @@ class Script(scripts.Script): if opt.label == 'Nothing': return [0] - valslist = list(map(str.strip,chain.from_iterable(csv.reader(StringIO(vals))))) + valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))] if opt.type == int: valslist_ext = [] -- cgit v1.2.3 From bad7cb29cecac51c5c0f39afec332b007ed73133 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 7 Oct 2022 10:17:52 +0300 Subject: added support for hypernetworks (???) --- modules/hypernetwork.py | 55 ++++++++++++++++++++++++++++++++++++++ modules/sd_hijack_optimizations.py | 17 ++++++++++-- modules/shared.py | 9 ++++++- scripts/xy_grid.py | 10 +++++++ 4 files changed, 88 insertions(+), 3 deletions(-) create mode 100644 modules/hypernetwork.py (limited to 'scripts/xy_grid.py') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py new file mode 100644 index 00000000..9ed1eed9 --- /dev/null +++ b/modules/hypernetwork.py @@ -0,0 +1,55 @@ +import glob +import os +import torch +from modules import devices + + +class HypernetworkModule(torch.nn.Module): + def __init__(self, dim, state_dict): + super().__init__() + + self.linear1 = torch.nn.Linear(dim, dim * 2) + self.linear2 = torch.nn.Linear(dim * 2, dim) + + self.load_state_dict(state_dict, strict=True) + self.to(devices.device) + + def forward(self, x): + return x + (self.linear2(self.linear1(x))) + + +class Hypernetwork: + filename = None + name = None + + def __init__(self, filename): + self.filename = filename + self.name = os.path.splitext(os.path.basename(filename))[0] + self.layers = {} + + state_dict = torch.load(filename, map_location='cpu') + for size, sd in state_dict.items(): + self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) + + +def load_hypernetworks(path): + res = {} + + for filename in glob.iglob(path + '**/*.pt', recursive=True): + hn = Hypernetwork(filename) + res[hn.name] = hn + + return res + +def apply(self, x, context=None, mask=None, original=None): + + + if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork: + if context.shape[1] == 77 and CrossAttention.noise_cond: + context = context + (torch.randn_like(context) * 0.1) + h_k, h_v = CrossAttention.hypernetwork[context.shape[2]] + k = self.to_k(h_k(context)) + v = self.to_v(h_v(context)) + else: + k = self.to_k(context) + v = self.to_v(context) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index ea4cfdfc..d9cca485 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -5,6 +5,8 @@ from torch import einsum from ldm.util import default from einops import rearrange +from modules import shared + # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): @@ -42,8 +44,19 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - k_in = self.to_k(context) * self.scale - v_in = self.to_v(context) + + hypernetwork = shared.selected_hypernetwork() + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is not None: + k_in = self.to_k(hypernetwork_layers[0](context)) + v_in = self.to_v(hypernetwork_layers[1](context)) + else: + k_in = self.to_k(context) + v_in = self.to_v(context) + + k_in *= self.scale + del context, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) diff --git a/modules/shared.py b/modules/shared.py index 25bb6e6c..879d8424 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers +from modules import sd_samplers, hypernetwork from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') @@ -76,6 +76,12 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram config_filename = cmd_opts.ui_settings_file +hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks')) + + +def selected_hypernetwork(): + return hypernetworks.get(opts.sd_hypernetwork, None) + class State: interrupted = False @@ -206,6 +212,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), + "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 6344e612..c0c364df 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -77,6 +77,11 @@ def apply_checkpoint(p, x, xs): modules.sd_models.reload_model_weights(shared.sd_model, info) +def apply_hypernetwork(p, x, xs): + hn = shared.hypernetworks.get(x, None) + opts.data["sd_hypernetwork"] = hn.name if hn is not None else 'None' + + def format_value_add_label(p, opt, x): if type(x) == float: x = round(x, 8) @@ -122,6 +127,7 @@ axis_options = [ AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list), AxisOption("Sampler", str, apply_sampler, format_value), AxisOption("Checkpoint name", str, apply_checkpoint, format_value), + AxisOption("Hypernetwork", str, apply_hypernetwork, format_value), AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label), AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label), AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), @@ -193,6 +199,8 @@ class Script(scripts.Script): modules.processing.fix_seed(p) p.batch_size = 1 + initial_hn = opts.sd_hypernetwork + def process_axis(opt, vals): if opt.label == 'Nothing': return [0] @@ -300,4 +308,6 @@ class Script(scripts.Script): # restore checkpoint in case it was changed by axes modules.sd_models.reload_model_weights(shared.sd_model) + opts.data["sd_hypernetwork"] = initial_hn + return processed -- cgit v1.2.3 From 03e570886f430f39020e504aba057a95f2e62484 Mon Sep 17 00:00:00 2001 From: frostydad <64224601+Cyberes@users.noreply.github.com> Date: Sat, 8 Oct 2022 18:13:13 -0600 Subject: Fix incorrect sampler name in output --- modules/processing.py | 9 ++++++++- scripts/xy_grid.py | 16 +++++++++------- 2 files changed, 17 insertions(+), 8 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/processing.py b/modules/processing.py index 4fea6d56..6b8664a0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1,3 +1,4 @@ + import json import math import os @@ -46,6 +47,12 @@ def apply_color_correction(correction, image): return image +def get_correct_sampler(p): + if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img): + return sd_samplers.samplers + elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img): + return sd_samplers.samplers_for_img2img + class StableDiffusionProcessing: def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None): self.sd_model = sd_model @@ -272,7 +279,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params = { "Steps": p.steps, - "Sampler": sd_samplers.samplers[p.sampler_index].name, + "Sampler": get_correct_sampler(p)[p.sampler_index].name, "CFG scale": p.cfg_scale, "Seed": all_seeds[index], "Face restoration": (opts.face_restoration_model if p.restore_faces else None), diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index c0c364df..26ae2199 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -11,7 +11,7 @@ import modules.scripts as scripts import gradio as gr from modules import images -from modules.processing import process_images, Processed +from modules.processing import process_images, Processed, get_correct_sampler from modules.shared import opts, cmd_opts, state import modules.shared as shared import modules.sd_samplers @@ -56,15 +56,17 @@ def apply_order(p, x, xs): p.prompt = prompt_tmp + p.prompt -samplers_dict = {} -for i, sampler in enumerate(modules.sd_samplers.samplers): - samplers_dict[sampler.name.lower()] = i - for alias in sampler.aliases: - samplers_dict[alias.lower()] = i +def build_samplers_dict(p): + samplers_dict = {} + for i, sampler in enumerate(get_correct_sampler(p)): + samplers_dict[sampler.name.lower()] = i + for alias in sampler.aliases: + samplers_dict[alias.lower()] = i + return samplers_dict def apply_sampler(p, x, xs): - sampler_index = samplers_dict.get(x.lower(), None) + sampler_index = build_samplers_dict(p).get(x.lower(), None) if sampler_index is None: raise RuntimeError(f"Unknown sampler: {x}") -- cgit v1.2.3 From d74c38108f95e44d83a1706ee5ab218124972868 Mon Sep 17 00:00:00 2001 From: Jesse Williams <33797815+xram64@users.noreply.github.com> Date: Sat, 8 Oct 2022 01:30:49 -0400 Subject: Confirm that options are valid before starting When using the 'Sampler' or 'Checkpoint' options, if one of the entered names has a typo, an error will only be thrown once the `draw_xy_grid` loop reaches that name. This can waste a lot of time for large grids with a typo near the end of a list, since the script needs to start over and re-generate any earlier images to finish making the grid. Also fixing typo in variable name in `draw_xy_grid`. --- scripts/xy_grid.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 26ae2199..07040886 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -145,7 +145,7 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): ver_texts = [[images.GridAnnotation(y)] for y in y_labels] hor_texts = [[images.GridAnnotation(x)] for x in x_labels] - first_pocessed = None + first_processed = None state.job_count = len(xs) * len(ys) * p.n_iter @@ -154,8 +154,8 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" processed = cell(x, y) - if first_pocessed is None: - first_pocessed = processed + if first_processed is None: + first_processed = processed try: res.append(processed.images[0]) @@ -166,9 +166,9 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): if draw_legend: grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) - first_pocessed.images = [grid] + first_processed.images = [grid] - return first_pocessed + return first_processed re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") @@ -216,7 +216,6 @@ class Script(scripts.Script): m = re_range.fullmatch(val) mc = re_range_count.fullmatch(val) if m is not None: - start = int(m.group(1)) end = int(m.group(2))+1 step = int(m.group(3)) if m.group(3) is not None else 1 @@ -258,6 +257,16 @@ class Script(scripts.Script): valslist = list(permutations(valslist)) valslist = [opt.type(x) for x in valslist] + + # Confirm options are valid before starting + if opt.label == "Sampler": + for sampler_val in valslist: + if sampler_val.lower() not in samplers_dict.keys(): + raise RuntimeError(f"Unknown sampler: {sampler_val}") + elif opt.label == "Checkpoint name": + for ckpt_val in valslist: + if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None: + raise RuntimeError(f"Checkpoint for {ckpt_val} not found") return valslist -- cgit v1.2.3 From a65a45272e8f26ee3bc52a5300b396266508a9a5 Mon Sep 17 00:00:00 2001 From: Brendan Byrd Date: Thu, 6 Oct 2022 19:31:36 -0400 Subject: Don't change the seed initially if "Keep -1 for seeds" is checked Fixes #1049 --- scripts/xy_grid.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 07040886..a8f53bef 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -198,7 +198,9 @@ class Script(scripts.Script): return [x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds] def run(self, p, x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds): - modules.processing.fix_seed(p) + if not no_fixed_seeds: + modules.processing.fix_seed(p) + p.batch_size = 1 initial_hn = opts.sd_hypernetwork -- cgit v1.2.3 From 542a3d3a4a00c1383fbdaf938ceefef87cf834bb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 14:33:22 +0300 Subject: fix btoken hypernetworks in XY plot --- modules/hypernetwork.py | 7 +++++-- scripts/xy_grid.py | 9 +++------ 2 files changed, 8 insertions(+), 8 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index 19f1c227..498bc9d8 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -49,15 +49,18 @@ def list_hypernetworks(path): def load_hypernetwork(filename): - print(f"Loading hypernetwork {filename}") path = shared.hypernetworks.get(filename, None) - if (path is not None): + if path is not None: + print(f"Loading hypernetwork {filename}") try: shared.loaded_hypernetwork = Hypernetwork(path) except Exception: print(f"Error loading hypernetwork {path}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) else: + if shared.loaded_hypernetwork is not None: + print(f"Unloading hypernetwork") + shared.loaded_hypernetwork = None diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index a8f53bef..fe949067 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,7 +10,7 @@ import numpy as np import modules.scripts as scripts import gradio as gr -from modules import images +from modules import images, hypernetwork from modules.processing import process_images, Processed, get_correct_sampler from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -80,8 +80,7 @@ def apply_checkpoint(p, x, xs): def apply_hypernetwork(p, x, xs): - hn = shared.hypernetworks.get(x, None) - opts.data["sd_hypernetwork"] = hn.name if hn is not None else 'None' + hypernetwork.load_hypernetwork(x) def format_value_add_label(p, opt, x): @@ -203,8 +202,6 @@ class Script(scripts.Script): p.batch_size = 1 - initial_hn = opts.sd_hypernetwork - def process_axis(opt, vals): if opt.label == 'Nothing': return [0] @@ -321,6 +318,6 @@ class Script(scripts.Script): # restore checkpoint in case it was changed by axes modules.sd_models.reload_model_weights(shared.sd_model) - opts.data["sd_hypernetwork"] = initial_hn + hypernetwork.load_hypernetwork(opts.sd_hypernetwork) return processed -- cgit v1.2.3 From 2c52f4da7ff80a3ec277105f4db6146c6379898a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 15:01:42 +0300 Subject: fix broken samplers in XY plot --- scripts/xy_grid.py | 1 + 1 file changed, 1 insertion(+) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index fe949067..c89ca1a9 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -259,6 +259,7 @@ class Script(scripts.Script): # Confirm options are valid before starting if opt.label == "Sampler": + samplers_dict = build_samplers_dict(p) for sampler_val in valslist: if sampler_val.lower() not in samplers_dict.keys(): raise RuntimeError(f"Unknown sampler: {sampler_val}") -- cgit v1.2.3 From 45bf9a6264b3507473e02cc3f9aa36559f24aca2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 18:58:55 +0300 Subject: added clip skip to XY plot --- scripts/xy_grid.py | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index c89ca1a9..7b0d9083 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -83,6 +83,10 @@ def apply_hypernetwork(p, x, xs): hypernetwork.load_hypernetwork(x) +def apply_clip_skip(p, x, xs): + opts.data["CLIP_ignore_last_layers"] = x + + def format_value_add_label(p, opt, x): if type(x) == float: x = round(x, 8) @@ -134,6 +138,7 @@ axis_options = [ AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label), AxisOption("Eta", float, apply_field("eta"), format_value_add_label), + AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label), AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones ] @@ -201,6 +206,7 @@ class Script(scripts.Script): modules.processing.fix_seed(p) p.batch_size = 1 + CLIP_ignore_last_layers = opts.CLIP_ignore_last_layers def process_axis(opt, vals): if opt.label == 'Nothing': @@ -321,4 +327,6 @@ class Script(scripts.Script): hypernetwork.load_hypernetwork(opts.sd_hypernetwork) + opts.data["CLIP_ignore_last_layers"] = CLIP_ignore_last_layers + return processed -- cgit v1.2.3 From 84ddd44113b36062e8ba6cb2e5db0fce4f48efb8 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sun, 9 Oct 2022 14:57:17 -0400 Subject: Clip skip variable name change breaks x/y plot script. This fixes that --- scripts/xy_grid.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 7b0d9083..771eb8e4 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -84,7 +84,7 @@ def apply_hypernetwork(p, x, xs): def apply_clip_skip(p, x, xs): - opts.data["CLIP_ignore_last_layers"] = x + opts.data["CLIP_stop_at_last_layers"] = x def format_value_add_label(p, opt, x): @@ -206,7 +206,7 @@ class Script(scripts.Script): modules.processing.fix_seed(p) p.batch_size = 1 - CLIP_ignore_last_layers = opts.CLIP_ignore_last_layers + CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers def process_axis(opt, vals): if opt.label == 'Nothing': @@ -327,6 +327,6 @@ class Script(scripts.Script): hypernetwork.load_hypernetwork(opts.sd_hypernetwork) - opts.data["CLIP_ignore_last_layers"] = CLIP_ignore_last_layers + opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers return processed -- cgit v1.2.3 From 5da1ba0e91a81804dc911d34c9a2e6956a23199c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 21:24:11 +0300 Subject: remove batch size restriction from X/Y plot --- scripts/xy_grid.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 771eb8e4..42e1489c 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -205,7 +205,10 @@ class Script(scripts.Script): if not no_fixed_seeds: modules.processing.fix_seed(p) - p.batch_size = 1 + if not opts.return_grid: + p.batch_size = 1 + + CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers def process_axis(opt, vals): -- cgit v1.2.3 From 255be75d30f41e089e499ec1c8462d6bf64dec24 Mon Sep 17 00:00:00 2001 From: aperullo <18688190+aperullo@users.noreply.github.com> Date: Tue, 11 Oct 2022 06:16:57 -0400 Subject: Error if prompt missing SR token to prevent mis-gens (#2209) --- scripts/xy_grid.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 42e1489c..10a82dc9 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -27,9 +27,16 @@ def apply_field(field): def apply_prompt(p, x, xs): + + orig_prompt = p.prompt + orig_negative_prompt = p.negative_prompt + p.prompt = p.prompt.replace(xs[0], x) p.negative_prompt = p.negative_prompt.replace(xs[0], x) + if p.prompt == orig_prompt and p.negative_prompt == orig_negative_prompt: + raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt. Did you forget to add the token?") + def apply_order(p, x, xs): token_order = [] -- cgit v1.2.3 From a8490e4019c359ff24824e004059744d7164361b Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 11:42:41 +0100 Subject: revert sr warning --- scripts/xy_grid.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 10a82dc9..99b3c4f6 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -35,7 +35,8 @@ def apply_prompt(p, x, xs): p.negative_prompt = p.negative_prompt.replace(xs[0], x) if p.prompt == orig_prompt and p.negative_prompt == orig_negative_prompt: - raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt. Did you forget to add the token?") + pass + #raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt. Did you forget to add the token?") def apply_order(p, x, xs): -- cgit v1.2.3 From 1a0a6a84c3149e236211d547471f5416cd1129f3 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Tue, 11 Oct 2022 11:59:56 +0100 Subject: add incorrect start word guard to xy_grid (#2259) --- scripts/xy_grid.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 99b3c4f6..9d4d6187 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -27,17 +27,12 @@ def apply_field(field): def apply_prompt(p, x, xs): - - orig_prompt = p.prompt - orig_negative_prompt = p.negative_prompt + if xs[0] not in p.prompt and xs[0] not in p.negative_prompt: + raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.") p.prompt = p.prompt.replace(xs[0], x) p.negative_prompt = p.negative_prompt.replace(xs[0], x) - if p.prompt == orig_prompt and p.negative_prompt == orig_negative_prompt: - pass - #raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt. Did you forget to add the token?") - def apply_order(p, x, xs): token_order = [] -- cgit v1.2.3 From 530103b586109c11fd068eb70ef09503ec6a4caf Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 14:53:02 +0300 Subject: fixes related to merge --- modules/hypernetwork.py | 103 ------------------------- modules/hypernetwork/hypernetwork.py | 74 +++++++++++------- modules/hypernetwork/ui.py | 10 +-- modules/sd_hijack_optimizations.py | 3 +- modules/shared.py | 13 +++- modules/textual_inversion/textual_inversion.py | 12 +-- modules/ui.py | 5 +- scripts/xy_grid.py | 3 +- webui.py | 15 +--- 9 files changed, 78 insertions(+), 160 deletions(-) delete mode 100644 modules/hypernetwork.py (limited to 'scripts/xy_grid.py') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py deleted file mode 100644 index 7bbc443e..00000000 --- a/modules/hypernetwork.py +++ /dev/null @@ -1,103 +0,0 @@ -import glob -import os -import sys -import traceback - -import torch - -from ldm.util import default -from modules import devices, shared -import torch -from torch import einsum -from einops import rearrange, repeat - - -class HypernetworkModule(torch.nn.Module): - def __init__(self, dim, state_dict): - super().__init__() - - self.linear1 = torch.nn.Linear(dim, dim * 2) - self.linear2 = torch.nn.Linear(dim * 2, dim) - - self.load_state_dict(state_dict, strict=True) - self.to(devices.device) - - def forward(self, x): - return x + (self.linear2(self.linear1(x))) - - -class Hypernetwork: - filename = None - name = None - - def __init__(self, filename): - self.filename = filename - self.name = os.path.splitext(os.path.basename(filename))[0] - self.layers = {} - - state_dict = torch.load(filename, map_location='cpu') - for size, sd in state_dict.items(): - self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) - - -def list_hypernetworks(path): - res = {} - for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): - name = os.path.splitext(os.path.basename(filename))[0] - res[name] = filename - return res - - -def load_hypernetwork(filename): - path = shared.hypernetworks.get(filename, None) - if path is not None: - print(f"Loading hypernetwork {filename}") - try: - shared.loaded_hypernetwork = Hypernetwork(path) - except Exception: - print(f"Error loading hypernetwork {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - if shared.loaded_hypernetwork is not None: - print(f"Unloading hypernetwork") - - shared.loaded_hypernetwork = None - - -def apply_hypernetwork(hypernetwork, context): - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is None: - return context, context - - context_k = hypernetwork_layers[0](context) - context_v = hypernetwork_layers[1](context) - return context_k, context_v - - -def attention_CrossAttention_forward(self, x, context=None, mask=None): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - - context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context) - k = self.to_k(context_k) - v = self.to_v(context_v) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - - if mask is not None: - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) - - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) diff --git a/modules/hypernetwork/hypernetwork.py b/modules/hypernetwork/hypernetwork.py index a3d6a47e..aa701bda 100644 --- a/modules/hypernetwork/hypernetwork.py +++ b/modules/hypernetwork/hypernetwork.py @@ -26,10 +26,11 @@ class HypernetworkModule(torch.nn.Module): if state_dict is not None: self.load_state_dict(state_dict, strict=True) else: - self.linear1.weight.data.fill_(0.0001) - self.linear1.bias.data.fill_(0.0001) - self.linear2.weight.data.fill_(0.0001) - self.linear2.bias.data.fill_(0.0001) + + self.linear1.weight.data.normal_(mean=0.0, std=0.01) + self.linear1.bias.data.zero_() + self.linear2.weight.data.normal_(mean=0.0, std=0.01) + self.linear2.bias.data.zero_() self.to(devices.device) @@ -92,41 +93,54 @@ class Hypernetwork: self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) -def load_hypernetworks(path): +def list_hypernetworks(path): res = {} + for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): + name = os.path.splitext(os.path.basename(filename))[0] + res[name] = filename + return res - for filename in glob.iglob(path + '**/*.pt', recursive=True): + +def load_hypernetwork(filename): + path = shared.hypernetworks.get(filename, None) + if path is not None: + print(f"Loading hypernetwork {filename}") try: - hn = Hypernetwork() - hn.load(filename) - res[hn.name] = hn + shared.loaded_hypernetwork = Hypernetwork() + shared.loaded_hypernetwork.load(path) + except Exception: - print(f"Error loading hypernetwork {filename}", file=sys.stderr) + print(f"Error loading hypernetwork {path}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + else: + if shared.loaded_hypernetwork is not None: + print(f"Unloading hypernetwork") - return res + shared.loaded_hypernetwork = None -def attention_CrossAttention_forward(self, x, context=None, mask=None): - h = self.heads +def apply_hypernetwork(hypernetwork, context, layer=None): + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - q = self.to_q(x) - context = default(context, x) + if hypernetwork_layers is None: + return context, context - hypernetwork_layers = (shared.hypernetwork.layers if shared.hypernetwork is not None else {}).get(context.shape[2], None) + if layer is not None: + layer.hyper_k = hypernetwork_layers[0] + layer.hyper_v = hypernetwork_layers[1] - if hypernetwork_layers is not None: - hypernetwork_k, hypernetwork_v = hypernetwork_layers + context_k = hypernetwork_layers[0](context) + context_v = hypernetwork_layers[1](context) + return context_k, context_v - self.hypernetwork_k = hypernetwork_k - self.hypernetwork_v = hypernetwork_v - context_k = hypernetwork_k(context) - context_v = hypernetwork_v(context) - else: - context_k = context - context_v = context +def attention_CrossAttention_forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self) k = self.to_k(context_k) v = self.to_v(context_v) @@ -151,7 +165,9 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): assert hypernetwork_name, 'embedding not selected' - shared.hypernetwork = shared.hypernetworks[hypernetwork_name] + path = shared.hypernetworks.get(hypernetwork_name, None) + shared.loaded_hypernetwork = Hypernetwork() + shared.loaded_hypernetwork.load(path) shared.state.textinfo = "Initializing hypernetwork training..." shared.state.job_count = steps @@ -176,9 +192,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) - hypernetwork = shared.hypernetworks[hypernetwork_name] + hypernetwork = shared.loaded_hypernetwork weights = hypernetwork.weights() for weight in weights: weight.requires_grad = True @@ -194,7 +210,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, if ititial_step > steps: return hypernetwork, filename - pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, (x, text) in pbar: hypernetwork.step = i + ititial_step diff --git a/modules/hypernetwork/ui.py b/modules/hypernetwork/ui.py index 525f978c..f6d1d0a3 100644 --- a/modules/hypernetwork/ui.py +++ b/modules/hypernetwork/ui.py @@ -6,24 +6,24 @@ import gradio as gr import modules.textual_inversion.textual_inversion import modules.textual_inversion.preprocess from modules import sd_hijack, shared +from modules.hypernetwork import hypernetwork def create_hypernetwork(name): fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") assert not os.path.exists(fn), f"file {fn} already exists" - hypernetwork = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) - hypernetwork.save(fn) + hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) + hypernet.save(fn) shared.reload_hypernetworks() - shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None) return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" def train_hypernetwork(*args): - initial_hypernetwork = shared.hypernetwork + initial_hypernetwork = shared.loaded_hypernetwork try: sd_hijack.undo_optimizations() @@ -38,6 +38,6 @@ Hypernetwork saved to {html.escape(filename)} except Exception: raise finally: - shared.hypernetwork = initial_hypernetwork + shared.loaded_hypernetwork = initial_hypernetwork sd_hijack.apply_optimizations() diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 25cb67a4..27e571fc 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -8,7 +8,8 @@ from torch import einsum from ldm.util import default from einops import rearrange -from modules import shared, hypernetwork +from modules import shared +from modules.hypernetwork import hypernetwork if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: diff --git a/modules/shared.py b/modules/shared.py index 14b40d70..8753015e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,8 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, hypernetwork +from modules import sd_samplers +from modules.hypernetwork import hypernetwork from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') @@ -29,6 +30,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") @@ -82,10 +84,17 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram xformers_available = False config_filename = cmd_opts.ui_settings_file -hypernetworks = hypernetwork.list_hypernetworks(os.path.join(models_path, 'hypernetworks')) +hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None +def reload_hypernetworks(): + global hypernetworks + + hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) + hypernetwork.load_hypernetwork(opts.sd_hypernetwork) + + class State: skipped = False interrupted = False diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5965c5a0..d6977950 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -156,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -238,12 +238,14 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') + preview_text = text if preview_image_prompt == "" else preview_image_prompt + p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, - prompt=text, + prompt=preview_text, steps=20, - height=training_height, - width=training_width, + height=training_height, + width=training_width, do_not_save_grid=True, do_not_save_samples=True, ) @@ -254,7 +256,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.current_image = image image.save(last_saved_image) - last_saved_image += f", prompt: {text}" + last_saved_image += f", prompt: {preview_text}" shared.state.job_no = embedding.step diff --git a/modules/ui.py b/modules/ui.py index 10b1ee3a..df653059 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1023,7 +1023,7 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="") with gr.Column(): - create_embedding = gr.Button(value="Create", variant='primary') + create_embedding = gr.Button(value="Create embedding", variant='primary') with gr.Group(): gr.HTML(value="

Create a new hypernetwork

") @@ -1035,7 +1035,7 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="") with gr.Column(): - create_hypernetwork = gr.Button(value="Create", variant='primary') + create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary') with gr.Group(): gr.HTML(value="

Preprocess images

") @@ -1147,6 +1147,7 @@ def create_ui(wrap_gradio_gpu_call): create_image_every, save_embedding_every, template_file, + preview_image_prompt, ], outputs=[ ti_output, diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 42e1489c..0af5993c 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,7 +10,8 @@ import numpy as np import modules.scripts as scripts import gradio as gr -from modules import images, hypernetwork +from modules import images +from modules.hypernetwork import hypernetwork from modules.processing import process_images, Processed, get_correct_sampler from modules.shared import opts, cmd_opts, state import modules.shared as shared diff --git a/webui.py b/webui.py index 7c200551..ba2156c8 100644 --- a/webui.py +++ b/webui.py @@ -29,6 +29,7 @@ from modules import devices from modules import modelloader from modules.paths import script_path from modules.shared import cmd_opts +import modules.hypernetwork.hypernetwork modelloader.cleanup_models() modules.sd_models.setup_model() @@ -77,22 +78,12 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) -def set_hypernetwork(): - shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None) - - -shared.reload_hypernetworks() -shared.opts.onchange("sd_hypernetwork", set_hypernetwork) -set_hypernetwork() - - modules.scripts.load_scripts(os.path.join(script_path, "scripts")) shared.sd_model = modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) -loaded_hypernetwork = modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork) -shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) +shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) def webui(): @@ -117,7 +108,7 @@ def webui(): prevent_thread_lock=True ) - app.add_middleware(GZipMiddleware,minimum_size=1000) + app.add_middleware(GZipMiddleware, minimum_size=1000) while 1: time.sleep(0.5) -- cgit v1.2.3 From 873efeed49bb5197a42da18272115b326c5d68f3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 15:51:22 +0300 Subject: rename hypernetwork dir to hypernetworks to prevent clash with an old filename that people who use zip instead of git clone will have --- modules/hypernetwork/hypernetwork.py | 283 ---------------------------------- modules/hypernetwork/ui.py | 43 ------ modules/hypernetworks/hypernetwork.py | 283 ++++++++++++++++++++++++++++++++++ modules/hypernetworks/ui.py | 43 ++++++ modules/sd_hijack.py | 2 +- modules/sd_hijack_optimizations.py | 2 +- modules/shared.py | 2 +- modules/ui.py | 2 +- scripts/xy_grid.py | 2 +- webui.py | 2 +- 10 files changed, 332 insertions(+), 332 deletions(-) delete mode 100644 modules/hypernetwork/hypernetwork.py delete mode 100644 modules/hypernetwork/ui.py create mode 100644 modules/hypernetworks/hypernetwork.py create mode 100644 modules/hypernetworks/ui.py (limited to 'scripts/xy_grid.py') diff --git a/modules/hypernetwork/hypernetwork.py b/modules/hypernetwork/hypernetwork.py deleted file mode 100644 index aa701bda..00000000 --- a/modules/hypernetwork/hypernetwork.py +++ /dev/null @@ -1,283 +0,0 @@ -import datetime -import glob -import html -import os -import sys -import traceback -import tqdm - -import torch - -from ldm.util import default -from modules import devices, shared, processing, sd_models -import torch -from torch import einsum -from einops import rearrange, repeat -import modules.textual_inversion.dataset - - -class HypernetworkModule(torch.nn.Module): - def __init__(self, dim, state_dict=None): - super().__init__() - - self.linear1 = torch.nn.Linear(dim, dim * 2) - self.linear2 = torch.nn.Linear(dim * 2, dim) - - if state_dict is not None: - self.load_state_dict(state_dict, strict=True) - else: - - self.linear1.weight.data.normal_(mean=0.0, std=0.01) - self.linear1.bias.data.zero_() - self.linear2.weight.data.normal_(mean=0.0, std=0.01) - self.linear2.bias.data.zero_() - - self.to(devices.device) - - def forward(self, x): - return x + (self.linear2(self.linear1(x))) - - -class Hypernetwork: - filename = None - name = None - - def __init__(self, name=None): - self.filename = None - self.name = name - self.layers = {} - self.step = 0 - self.sd_checkpoint = None - self.sd_checkpoint_name = None - - for size in [320, 640, 768, 1280]: - self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size)) - - def weights(self): - res = [] - - for k, layers in self.layers.items(): - for layer in layers: - layer.train() - res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias] - - return res - - def save(self, filename): - state_dict = {} - - for k, v in self.layers.items(): - state_dict[k] = (v[0].state_dict(), v[1].state_dict()) - - state_dict['step'] = self.step - state_dict['name'] = self.name - state_dict['sd_checkpoint'] = self.sd_checkpoint - state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name - - torch.save(state_dict, filename) - - def load(self, filename): - self.filename = filename - if self.name is None: - self.name = os.path.splitext(os.path.basename(filename))[0] - - state_dict = torch.load(filename, map_location='cpu') - - for size, sd in state_dict.items(): - if type(size) == int: - self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) - - self.name = state_dict.get('name', self.name) - self.step = state_dict.get('step', 0) - self.sd_checkpoint = state_dict.get('sd_checkpoint', None) - self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) - - -def list_hypernetworks(path): - res = {} - for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): - name = os.path.splitext(os.path.basename(filename))[0] - res[name] = filename - return res - - -def load_hypernetwork(filename): - path = shared.hypernetworks.get(filename, None) - if path is not None: - print(f"Loading hypernetwork {filename}") - try: - shared.loaded_hypernetwork = Hypernetwork() - shared.loaded_hypernetwork.load(path) - - except Exception: - print(f"Error loading hypernetwork {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - if shared.loaded_hypernetwork is not None: - print(f"Unloading hypernetwork") - - shared.loaded_hypernetwork = None - - -def apply_hypernetwork(hypernetwork, context, layer=None): - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is None: - return context, context - - if layer is not None: - layer.hyper_k = hypernetwork_layers[0] - layer.hyper_v = hypernetwork_layers[1] - - context_k = hypernetwork_layers[0](context) - context_v = hypernetwork_layers[1](context) - return context_k, context_v - - -def attention_CrossAttention_forward(self, x, context=None, mask=None): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - - context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self) - k = self.to_k(context_k) - v = self.to_v(context_v) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - - if mask is not None: - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) - - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) - - -def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): - assert hypernetwork_name, 'embedding not selected' - - path = shared.hypernetworks.get(hypernetwork_name, None) - shared.loaded_hypernetwork = Hypernetwork() - shared.loaded_hypernetwork.load(path) - - shared.state.textinfo = "Initializing hypernetwork training..." - shared.state.job_count = steps - - filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') - - log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) - - if save_hypernetwork_every > 0: - hypernetwork_dir = os.path.join(log_directory, "hypernetworks") - os.makedirs(hypernetwork_dir, exist_ok=True) - else: - hypernetwork_dir = None - - if create_image_every > 0: - images_dir = os.path.join(log_directory, "images") - os.makedirs(images_dir, exist_ok=True) - else: - images_dir = None - - cond_model = shared.sd_model.cond_stage_model - - shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." - with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) - - hypernetwork = shared.loaded_hypernetwork - weights = hypernetwork.weights() - for weight in weights: - weight.requires_grad = True - - optimizer = torch.optim.AdamW(weights, lr=learn_rate) - - losses = torch.zeros((32,)) - - last_saved_file = "" - last_saved_image = "" - - ititial_step = hypernetwork.step or 0 - if ititial_step > steps: - return hypernetwork, filename - - pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - for i, (x, text) in pbar: - hypernetwork.step = i + ititial_step - - if hypernetwork.step > steps: - break - - if shared.state.interrupted: - break - - with torch.autocast("cuda"): - c = cond_model([text]) - - x = x.to(devices.device) - loss = shared.sd_model(x.unsqueeze(0), c)[0] - del x - - losses[hypernetwork.step % losses.shape[0]] = loss.item() - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - pbar.set_description(f"loss: {losses.mean():.7f}") - - if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: - last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') - hypernetwork.save(last_saved_file) - - if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: - last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') - - preview_text = text if preview_image_prompt == "" else preview_image_prompt - - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - prompt=preview_text, - steps=20, - do_not_save_grid=True, - do_not_save_samples=True, - ) - - processed = processing.process_images(p) - image = processed.images[0] - - shared.state.current_image = image - image.save(last_saved_image) - - last_saved_image += f", prompt: {preview_text}" - - shared.state.job_no = hypernetwork.step - - shared.state.textinfo = f""" -

-Loss: {losses.mean():.7f}
-Step: {hypernetwork.step}
-Last prompt: {html.escape(text)}
-Last saved embedding: {html.escape(last_saved_file)}
-Last saved image: {html.escape(last_saved_image)}
-

-""" - - checkpoint = sd_models.select_checkpoint() - - hypernetwork.sd_checkpoint = checkpoint.hash - hypernetwork.sd_checkpoint_name = checkpoint.model_name - hypernetwork.save(filename) - - return hypernetwork, filename - - diff --git a/modules/hypernetwork/ui.py b/modules/hypernetwork/ui.py deleted file mode 100644 index f6d1d0a3..00000000 --- a/modules/hypernetwork/ui.py +++ /dev/null @@ -1,43 +0,0 @@ -import html -import os - -import gradio as gr - -import modules.textual_inversion.textual_inversion -import modules.textual_inversion.preprocess -from modules import sd_hijack, shared -from modules.hypernetwork import hypernetwork - - -def create_hypernetwork(name): - fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") - assert not os.path.exists(fn), f"file {fn} already exists" - - hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) - hypernet.save(fn) - - shared.reload_hypernetworks() - - return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" - - -def train_hypernetwork(*args): - - initial_hypernetwork = shared.loaded_hypernetwork - - try: - sd_hijack.undo_optimizations() - - hypernetwork, filename = modules.hypernetwork.hypernetwork.train_hypernetwork(*args) - - res = f""" -Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. -Hypernetwork saved to {html.escape(filename)} -""" - return res, "" - except Exception: - raise - finally: - shared.loaded_hypernetwork = initial_hypernetwork - sd_hijack.apply_optimizations() - diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py new file mode 100644 index 00000000..aa701bda --- /dev/null +++ b/modules/hypernetworks/hypernetwork.py @@ -0,0 +1,283 @@ +import datetime +import glob +import html +import os +import sys +import traceback +import tqdm + +import torch + +from ldm.util import default +from modules import devices, shared, processing, sd_models +import torch +from torch import einsum +from einops import rearrange, repeat +import modules.textual_inversion.dataset + + +class HypernetworkModule(torch.nn.Module): + def __init__(self, dim, state_dict=None): + super().__init__() + + self.linear1 = torch.nn.Linear(dim, dim * 2) + self.linear2 = torch.nn.Linear(dim * 2, dim) + + if state_dict is not None: + self.load_state_dict(state_dict, strict=True) + else: + + self.linear1.weight.data.normal_(mean=0.0, std=0.01) + self.linear1.bias.data.zero_() + self.linear2.weight.data.normal_(mean=0.0, std=0.01) + self.linear2.bias.data.zero_() + + self.to(devices.device) + + def forward(self, x): + return x + (self.linear2(self.linear1(x))) + + +class Hypernetwork: + filename = None + name = None + + def __init__(self, name=None): + self.filename = None + self.name = name + self.layers = {} + self.step = 0 + self.sd_checkpoint = None + self.sd_checkpoint_name = None + + for size in [320, 640, 768, 1280]: + self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size)) + + def weights(self): + res = [] + + for k, layers in self.layers.items(): + for layer in layers: + layer.train() + res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias] + + return res + + def save(self, filename): + state_dict = {} + + for k, v in self.layers.items(): + state_dict[k] = (v[0].state_dict(), v[1].state_dict()) + + state_dict['step'] = self.step + state_dict['name'] = self.name + state_dict['sd_checkpoint'] = self.sd_checkpoint + state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name + + torch.save(state_dict, filename) + + def load(self, filename): + self.filename = filename + if self.name is None: + self.name = os.path.splitext(os.path.basename(filename))[0] + + state_dict = torch.load(filename, map_location='cpu') + + for size, sd in state_dict.items(): + if type(size) == int: + self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) + + self.name = state_dict.get('name', self.name) + self.step = state_dict.get('step', 0) + self.sd_checkpoint = state_dict.get('sd_checkpoint', None) + self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) + + +def list_hypernetworks(path): + res = {} + for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): + name = os.path.splitext(os.path.basename(filename))[0] + res[name] = filename + return res + + +def load_hypernetwork(filename): + path = shared.hypernetworks.get(filename, None) + if path is not None: + print(f"Loading hypernetwork {filename}") + try: + shared.loaded_hypernetwork = Hypernetwork() + shared.loaded_hypernetwork.load(path) + + except Exception: + print(f"Error loading hypernetwork {path}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + else: + if shared.loaded_hypernetwork is not None: + print(f"Unloading hypernetwork") + + shared.loaded_hypernetwork = None + + +def apply_hypernetwork(hypernetwork, context, layer=None): + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is None: + return context, context + + if layer is not None: + layer.hyper_k = hypernetwork_layers[0] + layer.hyper_v = hypernetwork_layers[1] + + context_k = hypernetwork_layers[0](context) + context_v = hypernetwork_layers[1](context) + return context_k, context_v + + +def attention_CrossAttention_forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self) + k = self.to_k(context_k) + v = self.to_v(context_v) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if mask is not None: + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): + assert hypernetwork_name, 'embedding not selected' + + path = shared.hypernetworks.get(hypernetwork_name, None) + shared.loaded_hypernetwork = Hypernetwork() + shared.loaded_hypernetwork.load(path) + + shared.state.textinfo = "Initializing hypernetwork training..." + shared.state.job_count = steps + + filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') + + log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) + + if save_hypernetwork_every > 0: + hypernetwork_dir = os.path.join(log_directory, "hypernetworks") + os.makedirs(hypernetwork_dir, exist_ok=True) + else: + hypernetwork_dir = None + + if create_image_every > 0: + images_dir = os.path.join(log_directory, "images") + os.makedirs(images_dir, exist_ok=True) + else: + images_dir = None + + cond_model = shared.sd_model.cond_stage_model + + shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." + with torch.autocast("cuda"): + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) + + hypernetwork = shared.loaded_hypernetwork + weights = hypernetwork.weights() + for weight in weights: + weight.requires_grad = True + + optimizer = torch.optim.AdamW(weights, lr=learn_rate) + + losses = torch.zeros((32,)) + + last_saved_file = "" + last_saved_image = "" + + ititial_step = hypernetwork.step or 0 + if ititial_step > steps: + return hypernetwork, filename + + pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) + for i, (x, text) in pbar: + hypernetwork.step = i + ititial_step + + if hypernetwork.step > steps: + break + + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + c = cond_model([text]) + + x = x.to(devices.device) + loss = shared.sd_model(x.unsqueeze(0), c)[0] + del x + + losses[hypernetwork.step % losses.shape[0]] = loss.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + pbar.set_description(f"loss: {losses.mean():.7f}") + + if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') + hypernetwork.save(last_saved_file) + + if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: + last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') + + preview_text = text if preview_image_prompt == "" else preview_image_prompt + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + prompt=preview_text, + steps=20, + do_not_save_grid=True, + do_not_save_samples=True, + ) + + processed = processing.process_images(p) + image = processed.images[0] + + shared.state.current_image = image + image.save(last_saved_image) + + last_saved_image += f", prompt: {preview_text}" + + shared.state.job_no = hypernetwork.step + + shared.state.textinfo = f""" +

+Loss: {losses.mean():.7f}
+Step: {hypernetwork.step}
+Last prompt: {html.escape(text)}
+Last saved embedding: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+

+""" + + checkpoint = sd_models.select_checkpoint() + + hypernetwork.sd_checkpoint = checkpoint.hash + hypernetwork.sd_checkpoint_name = checkpoint.model_name + hypernetwork.save(filename) + + return hypernetwork, filename + + diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py new file mode 100644 index 00000000..811bc31e --- /dev/null +++ b/modules/hypernetworks/ui.py @@ -0,0 +1,43 @@ +import html +import os + +import gradio as gr + +import modules.textual_inversion.textual_inversion +import modules.textual_inversion.preprocess +from modules import sd_hijack, shared +from modules.hypernetworks import hypernetwork + + +def create_hypernetwork(name): + fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") + assert not os.path.exists(fn), f"file {fn} already exists" + + hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) + hypernet.save(fn) + + shared.reload_hypernetworks() + + return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" + + +def train_hypernetwork(*args): + + initial_hypernetwork = shared.loaded_hypernetwork + + try: + sd_hijack.undo_optimizations() + + hypernetwork, filename = modules.hypernetwork.hypernetwork.train_hypernetwork(*args) + + res = f""" +Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. +Hypernetwork saved to {html.escape(filename)} +""" + return res, "" + except Exception: + raise + finally: + shared.loaded_hypernetwork = initial_hypernetwork + sd_hijack.apply_optimizations() + diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f873049a..f07ec041 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -37,7 +37,7 @@ def apply_optimizations(): def undo_optimizations(): - from modules.hypernetwork import hypernetwork + from modules.hypernetworks import hypernetwork ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 27e571fc..3349b9c3 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -9,7 +9,7 @@ from ldm.util import default from einops import rearrange from modules import shared -from modules.hypernetwork import hypernetwork +from modules.hypernetworks import hypernetwork if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: diff --git a/modules/shared.py b/modules/shared.py index 375e3afb..1dc2ccf2 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.sd_models import modules.styles import modules.devices as devices from modules import sd_samplers -from modules.hypernetwork import hypernetwork +from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') diff --git a/modules/ui.py b/modules/ui.py index f57f32db..42e5d866 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -39,7 +39,7 @@ import modules.generation_parameters_copypaste from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui -import modules.hypernetwork.ui +import modules.hypernetworks.ui # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 16918c99..cddb192a 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -11,7 +11,7 @@ import modules.scripts as scripts import gradio as gr from modules import images -from modules.hypernetwork import hypernetwork +from modules.hypernetworks import hypernetwork from modules.processing import process_images, Processed, get_correct_sampler from modules.shared import opts, cmd_opts, state import modules.shared as shared diff --git a/webui.py b/webui.py index ba2156c8..faa38a0d 100644 --- a/webui.py +++ b/webui.py @@ -29,7 +29,7 @@ from modules import devices from modules import modelloader from modules.paths import script_path from modules.shared import cmd_opts -import modules.hypernetwork.hypernetwork +import modules.hypernetworks.hypernetwork modelloader.cleanup_models() modules.sd_models.setup_model() -- cgit v1.2.3 From 5ba23cb41f28f5856a7f64cb0d95e1e94dce90af Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 17:28:17 +0300 Subject: change default for XY plot's Y to Nothing. --- scripts/xy_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index cddb192a..ef431105 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -197,7 +197,7 @@ class Script(scripts.Script): x_values = gr.Textbox(label="X values", visible=False, lines=1) with gr.Row(): - y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[4].label, visible=False, type="index", elem_id="y_type") + y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, visible=False, type="index", elem_id="y_type") y_values = gr.Textbox(label="Y values", visible=False, lines=1) draw_legend = gr.Checkbox(label='Draw legend', value=True) -- cgit v1.2.3 From 2d006ce16cd95d587533656c3ac4991495e96f23 Mon Sep 17 00:00:00 2001 From: Milly Date: Mon, 10 Oct 2022 00:56:36 +0900 Subject: xy_grid: Find hypernetwork by closest name --- modules/hypernetworks/hypernetwork.py | 11 +++++++++++ scripts/xy_grid.py | 6 +++++- 2 files changed, 16 insertions(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 470659df..8f2192e2 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -120,6 +120,17 @@ def load_hypernetwork(filename): shared.loaded_hypernetwork = None +def find_closest_hypernetwork_name(search: str): + if not search: + return None + search = search.lower() + applicable = [name for name in shared.hypernetworks if search in name.lower()] + if not applicable: + return None + applicable = sorted(applicable, key=lambda name: len(name)) + return applicable[0] + + def apply_hypernetwork(hypernetwork, context, layer=None): hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index ef431105..6f4217ec 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -84,7 +84,11 @@ def apply_checkpoint(p, x, xs): def apply_hypernetwork(p, x, xs): - hypernetwork.load_hypernetwork(x) + if x.lower() in ["", "none"]: + name = None + else: + name = hypernetwork.find_closest_hypernetwork_name(x) + hypernetwork.load_hypernetwork(name) def apply_clip_skip(p, x, xs): -- cgit v1.2.3 From 7dba1c07cb337114507d9c256f9b843162c187d6 Mon Sep 17 00:00:00 2001 From: Milly Date: Mon, 10 Oct 2022 01:37:09 +0900 Subject: xy_grid: Confirm that hypernetwork options are valid before starting --- scripts/xy_grid.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 6f4217ec..b2239d0a 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -88,9 +88,19 @@ def apply_hypernetwork(p, x, xs): name = None else: name = hypernetwork.find_closest_hypernetwork_name(x) + if not name: + raise RuntimeError(f"Unknown hypernetwork: {x}") hypernetwork.load_hypernetwork(name) +def confirm_hypernetworks(xs): + for x in xs: + if x.lower() in ["", "none"]: + continue + if not hypernetwork.find_closest_hypernetwork_name(x): + raise RuntimeError(f"Unknown hypernetwork: {x}") + + def apply_clip_skip(p, x, xs): opts.data["CLIP_stop_at_last_layers"] = x @@ -284,6 +294,8 @@ class Script(scripts.Script): for ckpt_val in valslist: if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None: raise RuntimeError(f"Checkpoint for {ckpt_val} not found") + elif opt.label == "Hypernetwork": + confirm_hypernetworks(valslist) return valslist -- cgit v1.2.3 From 2fffd4bddce12b2c98a5bae5a2cc6d64450d65a0 Mon Sep 17 00:00:00 2001 From: Milly Date: Mon, 10 Oct 2022 02:20:35 +0900 Subject: xy_grid: Refactor confirm functions --- scripts/xy_grid.py | 73 +++++++++++++++++++++++++++++------------------------- 1 file changed, 39 insertions(+), 34 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index b2239d0a..3bb080bf 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -77,12 +77,26 @@ def apply_sampler(p, x, xs): p.sampler_index = sampler_index +def confirm_samplers(p, xs): + samplers_dict = build_samplers_dict(p) + for x in xs: + if x.lower() not in samplers_dict.keys(): + raise RuntimeError(f"Unknown sampler: {x}") + + def apply_checkpoint(p, x, xs): info = modules.sd_models.get_closet_checkpoint_match(x) - assert info is not None, f'Checkpoint for {x} not found' + if info is None: + raise RuntimeError(f"Unknown checkpoint: {x}") modules.sd_models.reload_model_weights(shared.sd_model, info) +def confirm_checkpoints(p, xs): + for x in xs: + if modules.sd_models.get_closet_checkpoint_match(x) is None: + raise RuntimeError(f"Unknown checkpoint: {x}") + + def apply_hypernetwork(p, x, xs): if x.lower() in ["", "none"]: name = None @@ -93,7 +107,7 @@ def apply_hypernetwork(p, x, xs): hypernetwork.load_hypernetwork(name) -def confirm_hypernetworks(xs): +def confirm_hypernetworks(p, xs): for x in xs: if x.lower() in ["", "none"]: continue @@ -135,29 +149,29 @@ def str_permutations(x): return x -AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"]) -AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"]) +AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"]) +AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"]) axis_options = [ - AxisOption("Nothing", str, do_nothing, format_nothing), - AxisOption("Seed", int, apply_field("seed"), format_value_add_label), - AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label), - AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label), - AxisOption("Steps", int, apply_field("steps"), format_value_add_label), - AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label), - AxisOption("Prompt S/R", str, apply_prompt, format_value), - AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list), - AxisOption("Sampler", str, apply_sampler, format_value), - AxisOption("Checkpoint name", str, apply_checkpoint, format_value), - AxisOption("Hypernetwork", str, apply_hypernetwork, format_value), - AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label), - AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label), - AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), - AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label), - AxisOption("Eta", float, apply_field("eta"), format_value_add_label), - AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label), - AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones + AxisOption("Nothing", str, do_nothing, format_nothing, None), + AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None), + AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None), + AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None), + AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None), + AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None), + AxisOption("Prompt S/R", str, apply_prompt, format_value, None), + AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None), + AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers), + AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints), + AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks), + AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None), + AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None), + AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None), + AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None), + AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), + AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), + AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), # as it is now all AxisOptionImg2Img items must go after AxisOption ones ] @@ -283,19 +297,10 @@ class Script(scripts.Script): valslist = list(permutations(valslist)) valslist = [opt.type(x) for x in valslist] - + # Confirm options are valid before starting - if opt.label == "Sampler": - samplers_dict = build_samplers_dict(p) - for sampler_val in valslist: - if sampler_val.lower() not in samplers_dict.keys(): - raise RuntimeError(f"Unknown sampler: {sampler_val}") - elif opt.label == "Checkpoint name": - for ckpt_val in valslist: - if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None: - raise RuntimeError(f"Checkpoint for {ckpt_val} not found") - elif opt.label == "Hypernetwork": - confirm_hypernetworks(valslist) + if opt.confirm: + opt.confirm(p, valslist) return valslist -- cgit v1.2.3 From 94c01aa35656130b56f401830ad443ce3d97c364 Mon Sep 17 00:00:00 2001 From: Greg Fuller Date: Tue, 11 Oct 2022 17:05:20 -0700 Subject: draw_xy_grid provides the option to also return lone images --- scripts/xy_grid.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 3bb080bf..14edacc1 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -175,8 +175,9 @@ axis_options = [ ] -def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): +def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images): res = [] + successful_images = [] ver_texts = [[images.GridAnnotation(y)] for y in y_labels] hor_texts = [[images.GridAnnotation(x)] for x in x_labels] @@ -194,7 +195,9 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): first_processed = processed try: - res.append(processed.images[0]) + processed_image = processed.images[0] + res.append(processed_image) + successful_images.append(processed_image) except: res.append(Image.new(res[0].mode, res[0].size)) @@ -203,6 +206,8 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) first_processed.images = [grid] + if include_lone_images: + first_processed.images += successful_images return first_processed @@ -229,11 +234,12 @@ class Script(scripts.Script): y_values = gr.Textbox(label="Y values", visible=False, lines=1) draw_legend = gr.Checkbox(label='Draw legend', value=True) + include_lone_images = gr.Checkbox(label='Include Separate Images', value=True) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False) - return [x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds] + return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] - def run(self, p, x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds): + def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds): if not no_fixed_seeds: modules.processing.fix_seed(p) @@ -344,7 +350,8 @@ class Script(scripts.Script): x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], cell=cell, - draw_legend=draw_legend + draw_legend=draw_legend, + include_lone_images=include_lone_images ) if opts.grid_save: -- cgit v1.2.3 From 8711c2fe0135d5c160a57db41cb79ed1942ce7fa Mon Sep 17 00:00:00 2001 From: Greg Fuller Date: Wed, 12 Oct 2022 16:12:12 -0700 Subject: Fix metadata contents --- modules/ui.py | 5 +---- scripts/xy_grid.py | 52 ++++++++++++++++++++++++++++++++++------------------ 2 files changed, 35 insertions(+), 22 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/ui.py b/modules/ui.py index 4fa405a9..e07ee0e1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -148,10 +148,7 @@ def save_files(js_data, images, do_make_zip, index): is_grid = image_index < p.index_of_first_image i = 0 if is_grid else (image_index - p.index_of_first_image) - seed = p.all_seeds[i] if len(p.all_seeds) > 1 else p.seed - prompt = p.all_prompts[i] if len(p.all_prompts) > 1 else p.prompt - info = p.infotexts[image_index] if len(p.infotexts) > 1 else p.infotexts[0] - fullfn, txt_fullfn = save_image(image, path, "", seed=seed, prompt=prompt, extension=extension, info=info, grid=is_grid, p=p, save_to_dirs=save_to_dirs) + fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) filename = os.path.relpath(fullfn, path) filenames.append(filename) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 14edacc1..02931ae6 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -176,13 +176,16 @@ axis_options = [ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images): - res = [] - successful_images = [] - ver_texts = [[images.GridAnnotation(y)] for y in y_labels] hor_texts = [[images.GridAnnotation(x)] for x in x_labels] - first_processed = None + # Temporary list of all the images that are generated to be populated into the grid. + # Will be filled with empty images for any individual step that fails to process properly + image_cache = [] + + processed_result = None + cell_mode = "P" + cell_size = (1,1) state.job_count = len(xs) * len(ys) * p.n_iter @@ -190,26 +193,39 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_ for ix, x in enumerate(xs): state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" - processed = cell(x, y) - if first_processed is None: - first_processed = processed - + processed:Processed = cell(x, y) try: - processed_image = processed.images[0] - res.append(processed_image) - successful_images.append(processed_image) + # this dereference will throw an exception if the image was not processed + # (this happens in cases such as if the user stops the process from the UI) + processed_image = processed.images[0] + + if processed_result is None: + # Use our first valid processed result as a template container to hold our full results + processed_result = copy(processed) + cell_mode = processed_image.mode + cell_size = processed_image.size + processed_result.images = [Image.new(cell_mode, cell_size)] + + image_cache.append(processed_image) + if include_lone_images: + processed_result.images.append(processed_image) + processed_result.all_prompts.append(processed.prompt) + processed_result.all_seeds.append(processed.seed) + processed_result.infotexts.append(processed.infotexts[0]) except: - res.append(Image.new(res[0].mode, res[0].size)) + image_cache.append(Image.new(cell_mode, cell_size)) + + if not processed_result: + print("Unexpected error: draw_xy_grid failed to return even a single processed image") + return Processed() - grid = images.image_grid(res, rows=len(ys)) + grid = images.image_grid(image_cache, rows=len(ys)) if draw_legend: - grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) + grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts) - first_processed.images = [grid] - if include_lone_images: - first_processed.images += successful_images + processed_result.images[0] = grid - return first_processed + return processed_result re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") -- cgit v1.2.3 From 354ef0da3b1f0fa5c113d04b6c79e3908c848d23 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 13 Oct 2022 20:12:37 +0300 Subject: add hypernetwork multipliers --- modules/hypernetworks/hypernetwork.py | 8 +++++++- modules/shared.py | 5 ++++- modules/ui.py | 5 ++++- scripts/xy_grid.py | 9 ++++++++- style.css | 3 +++ webui.py | 2 +- 6 files changed, 27 insertions(+), 5 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b6c06d49..f1248bb7 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -18,6 +18,8 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler class HypernetworkModule(torch.nn.Module): + multiplier = 1.0 + def __init__(self, dim, state_dict=None): super().__init__() @@ -36,7 +38,11 @@ class HypernetworkModule(torch.nn.Module): self.to(devices.device) def forward(self, x): - return x + (self.linear2(self.linear1(x))) + return x + (self.linear2(self.linear1(x))) * self.multiplier + + +def apply_strength(value=None): + HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength class Hypernetwork: diff --git a/modules/shared.py b/modules/shared.py index d8e3a286..5901e605 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -238,7 +238,8 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), - "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), + "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), + "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), @@ -348,6 +349,8 @@ class Options: item = self.data_labels.get(key) item.onchange = func + func() + def dumpjson(self): d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()} return json.dumps(d) diff --git a/modules/ui.py b/modules/ui.py index 0a58f6be..673014f2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1244,7 +1244,10 @@ def create_ui(wrap_gradio_gpu_call): def refresh(): info.refresh() refreshed_args = info.component_args() if callable(info.component_args) else info.component_args - res.choices = refreshed_args["choices"] + + for k, v in refreshed_args.items(): + setattr(res, k, v) + return gr.update(**(refreshed_args or {})) refresh_button.click( diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 02931ae6..efb63af5 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -107,6 +107,10 @@ def apply_hypernetwork(p, x, xs): hypernetwork.load_hypernetwork(name) +def apply_hypernetwork_strength(p, x, xs): + hypernetwork.apply_strength(x) + + def confirm_hypernetworks(p, xs): for x in xs: if x.lower() in ["", "none"]: @@ -165,6 +169,7 @@ axis_options = [ AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers), AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints), AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks), + AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None), AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None), AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None), AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None), @@ -250,7 +255,7 @@ class Script(scripts.Script): y_values = gr.Textbox(label="Y values", visible=False, lines=1) draw_legend = gr.Checkbox(label='Draw legend', value=True) - include_lone_images = gr.Checkbox(label='Include Separate Images', value=True) + include_lone_images = gr.Checkbox(label='Include Separate Images', value=False) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False) return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] @@ -377,6 +382,8 @@ class Script(scripts.Script): modules.sd_models.reload_model_weights(shared.sd_model) hypernetwork.load_hypernetwork(opts.sd_hypernetwork) + hypernetwork.apply_strength() + opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers diff --git a/style.css b/style.css index ad2a52cc..aa3d379c 100644 --- a/style.css +++ b/style.css @@ -522,6 +522,9 @@ canvas[key="mask"] { z-index: 200; width: 8em; } +#quicksettings .gr-box > div > div > input.gr-text-input { + top: -1.12em; +} .row.gr-compact{ overflow: visible; diff --git a/webui.py b/webui.py index 33ba7905..fe0ce321 100644 --- a/webui.py +++ b/webui.py @@ -72,7 +72,6 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) - def initialize(): modelloader.cleanup_models() modules.sd_models.setup_model() @@ -86,6 +85,7 @@ def initialize(): shared.sd_model = modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) + shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) def webui(): -- cgit v1.2.3 From 02382f7ce462a360e8aea9ee3178da48b564f70a Mon Sep 17 00:00:00 2001 From: RnDMonkey Date: Wed, 12 Oct 2022 16:35:36 -0700 Subject: regression in xy_grid Var. seed fixing --- scripts/xy_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index efb63af5..5700b007 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -338,7 +338,7 @@ class Script(scripts.Script): ys = process_axis(y_opt, y_values) def fix_axis_seeds(axis_opt, axis_list): - if axis_opt.label == 'Seed': + if axis_opt.label in ['Seed','Var. seed']: return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] else: return axis_list -- cgit v1.2.3 From 4cc37e4cdf2ce5f5b753786b55ae1d4abd530c01 Mon Sep 17 00:00:00 2001 From: Naeaeaeaeae Date: Thu, 13 Oct 2022 18:49:58 +0200 Subject: [xy_grid.py] add option denoising_strength --- scripts/xy_grid.py | 1 + 1 file changed, 1 insertion(+) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 5700b007..fda2b71d 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -176,6 +176,7 @@ axis_options = [ AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None), AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), + AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), # as it is now all AxisOptionImg2Img items must go after AxisOption ones ] -- cgit v1.2.3 From 989a552de3d1fcd1f178fe873713b884e192dd61 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 14 Oct 2022 22:04:08 +0300 Subject: remove the other Denoising --- scripts/xy_grid.py | 1 - 1 file changed, 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index fda2b71d..8c7da6bb 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -177,7 +177,6 @@ axis_options = [ AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), - AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), # as it is now all AxisOptionImg2Img items must go after AxisOption ones ] -- cgit v1.2.3 From a8f7722e4e7460122b44589c3718eee0c597009d Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 14 Oct 2022 14:26:38 -0700 Subject: Fix XY-plot steps if highres fix is enabled --- scripts/xy_grid.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 8c7da6bb..88ad3bf7 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -12,7 +12,7 @@ import gradio as gr from modules import images from modules.hypernetworks import hypernetwork -from modules.processing import process_images, Processed, get_correct_sampler +from modules.processing import process_images, Processed, get_correct_sampler, StableDiffusionProcessingTxt2Img from modules.shared import opts, cmd_opts, state import modules.shared as shared import modules.sd_samplers @@ -354,6 +354,9 @@ class Script(scripts.Script): else: total_steps = p.steps * len(xs) * len(ys) + if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr: + total_steps *= 2 + print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})") shared.total_tqdm.updateTotal(total_steps * p.n_iter) -- cgit v1.2.3 From cccc5a20fce4bde9a4299f8790366790735f1d05 Mon Sep 17 00:00:00 2001 From: Greg Fuller Date: Sun, 16 Oct 2022 12:10:07 -0700 Subject: Safeguard setting restore logic against exceptions also useful for keeping settings cache and restore logic together, and nice for code reuse (other third party scripts can import this class) --- scripts/xy_grid.py | 48 ++++++++++++++++++++++++++---------------------- 1 file changed, 26 insertions(+), 22 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 88ad3bf7..5cca168a 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -233,6 +233,21 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_ return processed_result +class SharedSettingsStackHelper(object): + def __enter__(self): + self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers + self.hypernetwork = opts.sd_hypernetwork + self.model = shared.sd_model + + def __exit__(self, exc_type, exc_value, tb): + modules.sd_models.reload_model_weights(self.model) + + hypernetwork.load_hypernetwork(self.hypernetwork) + hypernetwork.apply_strength() + + opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers + + re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*") @@ -267,9 +282,6 @@ class Script(scripts.Script): if not opts.return_grid: p.batch_size = 1 - - CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers - def process_axis(opt, vals): if opt.label == 'Nothing': return [0] @@ -367,27 +379,19 @@ class Script(scripts.Script): return process_images(pc) - processed = draw_xy_grid( - p, - xs=xs, - ys=ys, - x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], - y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], - cell=cell, - draw_legend=draw_legend, - include_lone_images=include_lone_images - ) + with SharedSettingsStackHelper(): + processed = draw_xy_grid( + p, + xs=xs, + ys=ys, + x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], + y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], + cell=cell, + draw_legend=draw_legend, + include_lone_images=include_lone_images + ) if opts.grid_save: images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p) - # restore checkpoint in case it was changed by axes - modules.sd_models.reload_model_weights(shared.sd_model) - - hypernetwork.load_hypernetwork(opts.sd_hypernetwork) - hypernetwork.apply_strength() - - - opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers - return processed -- cgit v1.2.3 From 49533eed9e3aad19e9868ee140708baec4fd44be Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Thu, 20 Oct 2022 16:01:27 -0700 Subject: XY grid correctly re-assignes model when config changes --- modules/sd_models.py | 6 +++--- scripts/xy_grid.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 7072db08..fea84630 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -204,9 +204,9 @@ def load_model_weights(model, checkpoint_info): model.sd_checkpoint_info = checkpoint_info -def load_model(): +def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack - checkpoint_info = select_checkpoint() + checkpoint_info = checkpoint_info or select_checkpoint() if checkpoint_info.config != shared.cmd_opts.config: print(f"Loading config from: {checkpoint_info.config}") @@ -249,7 +249,7 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - shared.sd_model = load_model() + shared.sd_model = load_model(checkpoint_info) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 5cca168a..eff0c942 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -89,6 +89,7 @@ def apply_checkpoint(p, x, xs): if info is None: raise RuntimeError(f"Unknown checkpoint: {x}") modules.sd_models.reload_model_weights(shared.sd_model, info) + p.sd_model = shared.sd_model def confirm_checkpoints(p, xs): -- cgit v1.2.3 From 605d27687f433c0fefb9025aace7dc94f0ebd454 Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Tue, 25 Oct 2022 12:20:54 -0700 Subject: Added conditioning image masking to xy_grid. Use `True` and `False` to select values. --- modules/processing.py | 2 +- scripts/xy_grid.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/processing.py b/modules/processing.py index 96f56b0d..23ee5e02 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -770,7 +770,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): conditioning_mask = conditioning_mask.to(image.device) conditioning_image = image - if shared.opts.inpainting_mask_image: + if getattr(self, "inpainting_mask_image", shared.opts.inpainting_mask_image): conditioning_image = conditioning_image * (1.0 - conditioning_mask) conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index eff0c942..0843adcc 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -153,6 +153,8 @@ def str_permutations(x): """dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" return x +def str_to_bool(x): + return "true" in x.lower().strip() AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"]) AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"]) @@ -178,6 +180,7 @@ axis_options = [ AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), + AxisOption("Mask Conditioning Image", str_to_bool, apply_field("inpainting_mask_image"), format_value_add_label, None), ] -- cgit v1.2.3 From 8b4f32779f28010fc8077e8fcfb85a3205b36bc2 Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Tue, 25 Oct 2022 13:15:08 -0700 Subject: Switch to a continous blend for cond. image. --- modules/processing.py | 9 ++++++--- modules/shared.py | 2 +- scripts/xy_grid.py | 5 +---- 3 files changed, 8 insertions(+), 8 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/processing.py b/modules/processing.py index 23ee5e02..02292bdc 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -769,9 +769,12 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): # Create another latent image, this time with a masked version of the original input. conditioning_mask = conditioning_mask.to(image.device) - conditioning_image = image - if getattr(self, "inpainting_mask_image", shared.opts.inpainting_mask_image): - conditioning_image = conditioning_image * (1.0 - conditioning_mask) + # Smoothly interpolate between the masked and unmasked latent conditioning image. + conditioning_image = torch.lerp( + image, + image * (1.0 - conditioning_mask), + getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) + ) conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) diff --git a/modules/shared.py b/modules/shared.py index 1d0ff1a1..e0ffb824 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -320,7 +320,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), - "inpainting_mask_image": OptionInfo(True, "Mask original image for conditioning used by inpainting model."), + "inpainting_mask_weight": OptionInfo(1.0, "Blend betweeen an unmasked and masked conditioning image for inpainting models.", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), })) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 0843adcc..f5255786 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -153,9 +153,6 @@ def str_permutations(x): """dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" return x -def str_to_bool(x): - return "true" in x.lower().strip() - AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"]) AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"]) @@ -180,7 +177,7 @@ axis_options = [ AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), - AxisOption("Mask Conditioning Image", str_to_bool, apply_field("inpainting_mask_image"), format_value_add_label, None), + AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None), ] -- cgit v1.2.3 From 4dd898b8c15e342f817d3fb1c8dc9f2d5d111022 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 4 Nov 2022 08:38:11 +0300 Subject: do not mess with components' visibility for scripts; instead create group components and show/hide those; this will break scripts that create invisible components and rely on UI but the earlier i make this change the better --- modules/scripts.py | 34 ++++++++++++++++++---------------- scripts/custom_code.py | 2 +- scripts/outpainting_mk_2.py | 2 +- scripts/poor_mans_outpainting.py | 4 ++-- scripts/prompts_from_file.py | 10 +++++----- scripts/sd_upscale.py | 4 ++-- scripts/xy_grid.py | 8 ++++---- 7 files changed, 33 insertions(+), 31 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/scripts.py b/modules/scripts.py index 533db45c..28ce07f4 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -18,6 +18,9 @@ class Script: args_to = None alwayson = False + """A gr.Group component that has all script's UI inside it""" + group = None + infotext_fields = None """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example @@ -218,8 +221,6 @@ class ScriptRunner: for control in controls: control.custom_script_source = os.path.basename(script.filename) - if not script.alwayson: - control.visible = False if script.infotext_fields is not None: self.infotext_fields += script.infotext_fields @@ -229,40 +230,41 @@ class ScriptRunner: script.args_to = len(inputs) for script in self.alwayson_scripts: - with gr.Group(): + with gr.Group() as group: create_script_ui(script, inputs, inputs_alwayson) + script.group = group + dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index") dropdown.save_to_config = True inputs[0] = dropdown for script in self.selectable_scripts: - create_script_ui(script, inputs, inputs_alwayson) + with gr.Group(visible=False) as group: + create_script_ui(script, inputs, inputs_alwayson) + + script.group = group def select_script(script_index): - if 0 < script_index <= len(self.selectable_scripts): - script = self.selectable_scripts[script_index-1] - args_from = script.args_from - args_to = script.args_to - else: - args_from = 0 - args_to = 0 + selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None - return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)] + return [gr.update(visible=selected_script == s) for s in self.selectable_scripts] def init_field(title): + """called when an initial value is set from ui-config.json to show script's UI components""" + if title == 'None': return + script_index = self.titles.index(title) - script = self.selectable_scripts[script_index] - for i in range(script.args_from, script.args_to): - inputs[i].visible = True + self.selectable_scripts[script_index].group.visible = True dropdown.init_field = init_field + dropdown.change( fn=select_script, inputs=[dropdown], - outputs=inputs + outputs=[script.group for script in self.selectable_scripts] ) return inputs diff --git a/scripts/custom_code.py b/scripts/custom_code.py index a9b10c09..22e7b77a 100644 --- a/scripts/custom_code.py +++ b/scripts/custom_code.py @@ -14,7 +14,7 @@ class Script(scripts.Script): return cmd_opts.allow_code def ui(self, is_img2img): - code = gr.Textbox(label="Python code", visible=False, lines=1) + code = gr.Textbox(label="Python code", lines=1) return [code] diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index 2afd4aa5..cf71cb92 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -132,7 +132,7 @@ class Script(scripts.Script): info = gr.HTML("

Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8

") pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128) - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, visible=False) + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down']) noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0) color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05) diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index b0469110..ea45beb0 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -22,8 +22,8 @@ class Script(scripts.Script): return None pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128) - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, visible=False) - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", visible=False) + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4) + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index") direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down']) return [pixels, mask_blur, inpainting_fill, direction] diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index d187cd9c..3388bc77 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -83,13 +83,14 @@ def cmdargs(line): def load_prompt_file(file): - if (file is None): + if file is None: lines = [] else: lines = [x.strip() for x in file.decode('utf8', errors='ignore').split("\n")] return None, "\n".join(lines), gr.update(lines=7) + class Script(scripts.Script): def title(self): return "Prompts from file or textbox" @@ -107,9 +108,9 @@ class Script(scripts.Script): # We don't shrink back to 1, because that causes the control to ignore [enter], and it may # be unclear to the user that shift-enter is needed. prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt]) - return [checkbox_iterate, checkbox_iterate_batch, file, prompt_txt] + return [checkbox_iterate, checkbox_iterate_batch, prompt_txt] - def run(self, p, checkbox_iterate, checkbox_iterate_batch, file, prompt_txt: str): + def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str): lines = [x.strip() for x in prompt_txt.splitlines()] lines = [x for x in lines if len(x) > 0] @@ -157,5 +158,4 @@ class Script(scripts.Script): if checkbox_iterate: p.seed = p.seed + (p.batch_size * p.n_iter) - - return Processed(p, images, p.seed, "") \ No newline at end of file + return Processed(p, images, p.seed, "") diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py index cb37ff7e..01074291 100644 --- a/scripts/sd_upscale.py +++ b/scripts/sd_upscale.py @@ -18,8 +18,8 @@ class Script(scripts.Script): def ui(self, is_img2img): info = gr.HTML("

Will upscale the image to twice the dimensions; use width and height sliders to set tile size

") - overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False) - upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", visible=False) + overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64) + upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") return [info, overlap, upscaler_index] diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index f5255786..417ed0d4 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -263,12 +263,12 @@ class Script(scripts.Script): current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img] with gr.Row(): - x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, visible=False, type="index", elem_id="x_type") - x_values = gr.Textbox(label="X values", visible=False, lines=1) + x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id="x_type") + x_values = gr.Textbox(label="X values", lines=1) with gr.Row(): - y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, visible=False, type="index", elem_id="y_type") - y_values = gr.Textbox(label="Y values", visible=False, lines=1) + y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id="y_type") + y_values = gr.Textbox(label="Y values", lines=1) draw_legend = gr.Checkbox(label='Draw legend', value=True) include_lone_images = gr.Checkbox(label='Include Separate Images', value=False) -- cgit v1.2.3 From 9cc48fee4859908deefbb917b9521dc8aa43a89e Mon Sep 17 00:00:00 2001 From: byzod Date: Sun, 6 Nov 2022 10:15:02 +0800 Subject: fix scripts ignores file format settings for grids --- scripts/xy_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 417ed0d4..45a78db2 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -393,6 +393,6 @@ class Script(scripts.Script): ) if opts.grid_save: - images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p) + images.save_image(processed.images[0], p.outpath_grids, "xy_grid", extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) return processed -- cgit v1.2.3 From cdc8020d13c5eef099c609b0a911ccf3568afc0d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 19 Nov 2022 12:01:51 +0300 Subject: change StableDiffusionProcessing to internally use sampler name instead of sampler index --- modules/api/api.py | 26 ++++++++--------------- modules/hypernetworks/hypernetwork.py | 4 ++-- modules/images.py | 2 +- modules/img2img.py | 4 ++-- modules/processing.py | 29 +++++++++++--------------- modules/sd_samplers.py | 13 +++++++++--- modules/textual_inversion/textual_inversion.py | 4 ++-- modules/txt2img.py | 3 ++- modules/ui.py | 2 +- scripts/img2imgalt.py | 4 ++-- scripts/xy_grid.py | 12 +++++------ 11 files changed, 49 insertions(+), 54 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/api/api.py b/modules/api/api.py index 596a6616..0eccccbb 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -6,9 +6,9 @@ from threading import Lock from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from fastapi import APIRouter, Depends, FastAPI, HTTPException import modules.shared as shared +from modules import sd_samplers from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.sd_samplers import all_samplers from modules.extras import run_extras, run_pnginfo from PIL import PngImagePlugin from modules.sd_models import checkpoints_list @@ -25,8 +25,12 @@ def upscaler_to_index(name: str): raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") -sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) +def validate_sampler_name(name): + config = sd_samplers.all_samplers_map.get(name, None) + if config is None: + raise HTTPException(status_code=404, detail="Sampler not found") + return name def setUpscalers(req: dict): reqDict = vars(req) @@ -82,14 +86,9 @@ class Api: self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): - sampler_index = sampler_to_index(txt2imgreq.sampler_index) - - if sampler_index is None: - raise HTTPException(status_code=404, detail="Sampler not found") - populate = txt2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, - "sampler_index": sampler_index[0], + "sampler_name": validate_sampler_name(txt2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True } @@ -109,12 +108,6 @@ class Api: return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI): - sampler_index = sampler_to_index(img2imgreq.sampler_index) - - if sampler_index is None: - raise HTTPException(status_code=404, detail="Sampler not found") - - init_images = img2imgreq.init_images if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") @@ -123,10 +116,9 @@ class Api: if mask: mask = decode_base64_to_image(mask) - populate = img2imgreq.copy(update={ # Override __init__ params "sd_model": shared.sd_model, - "sampler_index": sampler_index[0], + "sampler_name": validate_sampler_name(img2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True, "mask": mask @@ -272,7 +264,7 @@ class Api: return vars(shared.cmd_opts) def get_samplers(self): - return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers] + return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers] def get_upscalers(self): upscalers = [] diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 7f182712..fbb87dd1 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -12,7 +12,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared +from modules import devices, processing, sd_models, shared, sd_samplers from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -535,7 +535,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log p.prompt = preview_prompt p.negative_prompt = preview_negative_prompt p.steps = preview_steps - p.sampler_index = preview_sampler_index + p.sampler_name = sd_samplers.samplers[preview_sampler_index].name p.cfg_scale = preview_cfg_scale p.seed = preview_seed p.width = preview_width diff --git a/modules/images.py b/modules/images.py index ae705cbd..26d5b7a9 100644 --- a/modules/images.py +++ b/modules/images.py @@ -303,7 +303,7 @@ class FilenameGenerator: 'width': lambda self: self.image.width, 'height': lambda self: self.image.height, 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False), - 'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False), + 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False), 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash), 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'), 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime], [datetime