From e615d4f9d101e2712c7c2d0e3e8feb19cb430c74 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 2 Oct 2022 21:08:23 +0200 Subject: Convert folder icon surrogate pair to valid utf8 --- javascript/hints.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'javascript/hints.js') diff --git a/javascript/hints.js b/javascript/hints.js index 84694eeb..e72e9338 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -15,7 +15,7 @@ titles = { "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed", "\u{1f3a8}": "Add a random artist to the prompt.", "\u2199\ufe0f": "Read generation parameters from prompt into user interface.", - "\uD83D\uDCC2": "Open images output directory", + "\u{1f4c2}": "Open images output directory", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", -- cgit v1.2.3 From 556c36b9607e3f4eacdddc85f8e7a78b29476ea7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 4 Oct 2022 09:18:00 +0300 Subject: add hint, refactor code for #1607 --- javascript/hints.js | 1 + scripts/xy_grid.py | 35 ++++++++++++++++++----------------- 2 files changed, 19 insertions(+), 17 deletions(-) (limited to 'javascript/hints.js') diff --git a/javascript/hints.js b/javascript/hints.js index e72e9338..8adcd983 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -47,6 +47,7 @@ titles = { "Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work", "Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others", + "Prompt order": "Separate a list of words with commas, and the script will make a variation of prompt with those words for their every possible order", "Tiling": "Produce an image that can be tiled.", "Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.", diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 7def47f5..1237e754 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -29,10 +29,11 @@ def apply_prompt(p, x, xs): p.prompt = p.prompt.replace(xs[0], x) p.negative_prompt = p.negative_prompt.replace(xs[0], x) + def apply_order(p, x, xs): token_order = [] - # Initally grab the tokens from the prompt so they can be be replaced in order of earliest seen + # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen for token in x: token_order.append((p.prompt.find(token), token)) @@ -85,17 +86,26 @@ def format_value_add_label(p, opt, x): def format_value(p, opt, x): if type(x) == float: x = round(x, 8) - if type(x) == type(list()): - x = str(x) return x + +def format_value_join_list(p, opt, x): + return ", ".join(x) + + def do_nothing(p, x, xs): pass + def format_nothing(p, opt, x): return "" +def str_permutations(x): + """dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" + return x + + AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"]) AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"]) @@ -108,6 +118,7 @@ axis_options = [ AxisOption("Steps", int, apply_field("steps"), format_value_add_label), AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label), AxisOption("Prompt S/R", str, apply_prompt, format_value), + AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list), AxisOption("Sampler", str, apply_sampler, format_value), AxisOption("Checkpoint name", str, apply_checkpoint, format_value), AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label), @@ -115,7 +126,6 @@ axis_options = [ AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label), AxisOption("Eta", float, apply_field("eta"), format_value_add_label), - AxisOption("Prompt order", type(list()), apply_order, format_value), AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones ] @@ -158,6 +168,7 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") + class Script(scripts.Script): def title(self): return "X/Y plot" @@ -186,11 +197,7 @@ class Script(scripts.Script): if opt.label == 'Nothing': return [0] - if opt.type == type(list()): - valslist = [x for x in vals] - else: - valslist = [x.strip() for x in vals.split(",")] - + valslist = [x.strip() for x in vals.split(",")] if opt.type == int: valslist_ext = [] @@ -237,23 +244,17 @@ class Script(scripts.Script): valslist_ext.append(val) valslist = valslist_ext + elif opt.type == str_permutations: + valslist = list(permutations(valslist)) valslist = [opt.type(x) for x in valslist] return valslist x_opt = axis_options[x_type] - - if x_opt.label == "Prompt order": - x_values = list(permutations([x.strip() for x in x_values.split(",")])) - xs = process_axis(x_opt, x_values) y_opt = axis_options[y_type] - - if y_opt.label == "Prompt order": - y_values = list(permutations([y.strip() for y in y_values.split(",")])) - ys = process_axis(y_opt, y_values) def fix_axis_seeds(axis_opt, axis_list): -- cgit v1.2.3 From 786d9f63aaa4515df82eb2cf357ea92f3dae1e29 Mon Sep 17 00:00:00 2001 From: Trung Ngo Date: Tue, 4 Oct 2022 22:56:30 -0500 Subject: Add button to skip the current iteration --- javascript/hints.js | 1 + javascript/progressbar.js | 20 ++++++++++++++------ modules/img2img.py | 4 ++++ modules/processing.py | 4 ++++ modules/shared.py | 5 +++++ modules/ui.py | 8 ++++++++ style.css | 14 ++++++++++++-- webui.py | 1 + 8 files changed, 49 insertions(+), 8 deletions(-) (limited to 'javascript/hints.js') diff --git a/javascript/hints.js b/javascript/hints.js index 8adcd983..8e352e94 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -35,6 +35,7 @@ titles = { "Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.", "Denoising strength change factor": "In loopback mode, on each loop the denoising strength is multiplied by this value. <1 means decreasing variety so your sequence will converge on a fixed picture. >1 means increasing variety so your sequence will become more and more chaotic.", + "Skip": "Stop processing current image and continue processing.", "Interrupt": "Stop processing images and return any results accumulated so far.", "Save": "Write image to a directory (default - log/images) and generation parameters into csv file.", diff --git a/javascript/progressbar.js b/javascript/progressbar.js index f9e9290e..4395a215 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -1,8 +1,9 @@ // code related to showing and updating progressbar shown as the image is being made global_progressbars = {} -function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_interrupt, id_preview, id_gallery){ +function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){ var progressbar = gradioApp().getElementById(id_progressbar) + var skip = id_skip ? gradioApp().getElementById(id_skip) : null var interrupt = gradioApp().getElementById(id_interrupt) if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){ @@ -32,30 +33,37 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; if(!progressDiv){ + if (skip) { + skip.style.display = "none" + } interrupt.style.display = "none" } } - window.setTimeout(function(){ requestMoreProgress(id_part, id_progressbar_span, id_interrupt) }, 500) + window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500) }); mutationObserver.observe( progressbar, { childList:true, subtree:true }) } } onUiUpdate(function(){ - check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') - check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') - check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery') + check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') + check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') + check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery') }) -function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){ +function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){ btn = gradioApp().getElementById(id_part+"_check_progress"); if(btn==null) return; btn.click(); var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; + var skip = id_skip ? gradioApp().getElementById(id_skip) : null var interrupt = gradioApp().getElementById(id_interrupt) if(progressDiv && interrupt){ + if (skip) { + skip.style.display = "block" + } interrupt.style.display = "block" } } diff --git a/modules/img2img.py b/modules/img2img.py index da212d72..e60b7e0f 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -32,6 +32,10 @@ def process_batch(p, input_dir, output_dir, args): for i, image in enumerate(images): state.job = f"{i+1} out of {len(images)}" + if state.skipped: + state.skipped = False + state.interrupted = False + continue if state.interrupted: break diff --git a/modules/processing.py b/modules/processing.py index d814d5ac..6805039c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -355,6 +355,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: state.job_count = p.n_iter for n in range(p.n_iter): + if state.skipped: + state.skipped = False + state.interrupted = False + if state.interrupted: break diff --git a/modules/shared.py b/modules/shared.py index 864e772c..7f802bd9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -84,6 +84,7 @@ def selected_hypernetwork(): class State: + skipped = False interrupted = False job = "" job_no = 0 @@ -96,6 +97,10 @@ class State: current_image_sampling_step = 0 textinfo = None + def skip(self): + self.skipped = True + self.interrupted = True + def interrupt(self): self.interrupted = True diff --git a/modules/ui.py b/modules/ui.py index 4f18126f..e3e62fdd 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -191,6 +191,7 @@ def wrap_gradio_call(func, extra_outputs=None): # last item is always HTML res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" + shared.state.skipped = False shared.state.interrupted = False shared.state.job_count = 0 @@ -411,9 +412,16 @@ def create_toprow(is_img2img): with gr.Column(scale=1): with gr.Row(): + skip = gr.Button('Skip', elem_id=f"{id_part}_skip") interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') + skip.click( + fn=lambda: shared.state.skip(), + inputs=[], + outputs=[], + ) + interrupt.click( fn=lambda: shared.state.interrupt(), inputs=[], diff --git a/style.css b/style.css index 50c5e557..6904fc50 100644 --- a/style.css +++ b/style.css @@ -393,10 +393,20 @@ input[type="range"]{ #txt2img_interrupt, #img2img_interrupt{ position: absolute; - width: 100%; + width: 50%; height: 72px; background: #b4c0cc; - border-radius: 8px; + border-radius: 0px; + display: none; +} + +#txt2img_skip, #img2img_skip{ + position: absolute; + width: 50%; + right: 0px; + height: 72px; + background: #b4c0cc; + border-radius: 0px; display: none; } diff --git a/webui.py b/webui.py index 480360fe..3b4cf5e9 100644 --- a/webui.py +++ b/webui.py @@ -58,6 +58,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): shared.state.current_latent = None shared.state.current_image = None shared.state.current_image_sampling_step = 0 + shared.state.skipped = False shared.state.interrupted = False shared.state.textinfo = None -- cgit v1.2.3 From 39919c40dd18f5a14ae21403efea1b0f819756c7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 20:32:37 +0300 Subject: add eta noise seed delta option --- javascript/hints.js | 1 + modules/processing.py | 6 +++++- modules/shared.py | 1 + 3 files changed, 7 insertions(+), 1 deletion(-) (limited to 'javascript/hints.js') diff --git a/javascript/hints.js b/javascript/hints.js index 8e352e94..47b80776 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -79,6 +79,7 @@ titles = { "Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition", "Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.", + "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.", } diff --git a/modules/processing.py b/modules/processing.py index 50ba4fc5..698b3069 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -207,7 +207,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see # enables the generation of additional tensors with noise that the sampler will use during its processing. # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to # produce the same images as with two batches [100], [101]. - if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds: + if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0): sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))] else: sampler_noises = None @@ -247,6 +247,9 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see if sampler_noises is not None: cnt = p.sampler.number_of_needed_noises(p) + if opts.eta_noise_seed_delta > 0: + torch.manual_seed(seed + opts.eta_noise_seed_delta) + for j in range(cnt): sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape))) @@ -301,6 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Denoising strength": getattr(p, 'denoising_strength', None), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Clip skip": None if clip_skip <= 1 else clip_skip, + "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, } generation_params.update(p.extra_generation_params) diff --git a/modules/shared.py b/modules/shared.py index 5dfc344c..b1c65ecf 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -260,6 +260,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) -- cgit v1.2.3 From f98338faa84ecce503e68d8ba13d5f7bbae52730 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 23:15:48 +0300 Subject: add an option to not add watermark to created images --- javascript/hints.js | 1 + modules/shared.py | 1 + 2 files changed, 2 insertions(+) (limited to 'javascript/hints.js') diff --git a/javascript/hints.js b/javascript/hints.js index 47b80776..045f2d3c 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -80,6 +80,7 @@ titles = { "Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.", "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.", + "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be bevaing in an unethical manner.", } diff --git a/modules/shared.py b/modules/shared.py index da389f9c..ecd15ef5 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -173,6 +173,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"), "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"), + "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"), })) options_templates.update(options_section(('saving-paths', "Paths for saving"), { -- cgit v1.2.3 From ca5efc316b9431746ff886d259275310f63f95fb Mon Sep 17 00:00:00 2001 From: LunixWasTaken Date: Tue, 11 Oct 2022 22:04:56 +0200 Subject: Typo fix in watermark hint. --- javascript/hints.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'javascript/hints.js') diff --git a/javascript/hints.js b/javascript/hints.js index 045f2d3c..b81c181b 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -80,7 +80,7 @@ titles = { "Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.", "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.", - "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be bevaing in an unethical manner.", + "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.", } -- cgit v1.2.3 From c3c8eef9fd5a0c8b26319e32ca4a19b56204e6df Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 12 Oct 2022 20:49:47 +0300 Subject: train: change filename processing to be more simple and configurable train: make it possible to make text files with prompts train: rework scheduler so that there's less repeating code in textual inversion and hypernets train: move epochs setting to options --- javascript/hints.js | 3 ++ modules/hypernetworks/hypernetwork.py | 40 +++++++++------------- modules/shared.py | 3 ++ modules/textual_inversion/dataset.py | 47 +++++++++++++++++++------- modules/textual_inversion/learn_schedule.py | 37 +++++++++++++++++++- modules/textual_inversion/textual_inversion.py | 35 +++++++------------ modules/ui.py | 2 -- 7 files changed, 105 insertions(+), 62 deletions(-) (limited to 'javascript/hints.js') diff --git a/javascript/hints.js b/javascript/hints.js index b81c181b..d51ee14c 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -81,6 +81,9 @@ titles = { "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.", "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.", + + "Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.", + "Filename join string": "This string will be used to hoin split words into a single line if the option above is enabled.", } diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 8314450a..b6c06d49 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -14,7 +14,7 @@ import torch from torch import einsum from einops import rearrange, repeat import modules.textual_inversion.dataset -from modules.textual_inversion.learn_schedule import LearnSchedule +from modules.textual_inversion.learn_schedule import LearnRateScheduler class HypernetworkModule(torch.nn.Module): @@ -223,31 +223,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, if ititial_step > steps: return hypernetwork, filename - schedules = iter(LearnSchedule(learn_rate, steps, ititial_step)) - (learn_rate, end_step) = next(schedules) - print(f'Training at rate of {learn_rate} until step {end_step}') - - optimizer = torch.optim.AdamW(weights, lr=learn_rate) + scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - for i, (x, text, cond) in pbar: + for i, entry in pbar: hypernetwork.step = i + ititial_step - if hypernetwork.step > end_step: - try: - (learn_rate, end_step) = next(schedules) - except Exception: - break - tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}') - for pg in optimizer.param_groups: - pg['lr'] = learn_rate + scheduler.apply(optimizer, hypernetwork.step) + if scheduler.finished: + break if shared.state.interrupted: break with torch.autocast("cuda"): - cond = cond.to(devices.device) - x = x.to(devices.device) + cond = entry.cond.to(devices.device) + x = entry.latent.to(devices.device) loss = shared.sd_model(x.unsqueeze(0), cond)[0] del x del cond @@ -267,7 +259,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') - preview_text = text if preview_image_prompt == "" else preview_image_prompt + preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt optimizer.zero_grad() shared.sd_model.cond_stage_model.to(devices.device) @@ -282,16 +274,16 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, ) processed = processing.process_images(p) - image = processed.images[0] + image = processed.images[0] if len(processed.images)>0 else None if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) - shared.state.current_image = image - image.save(last_saved_image) - - last_saved_image += f", prompt: {preview_text}" + if image is not None: + shared.state.current_image = image + image.save(last_saved_image) + last_saved_image += f", prompt: {preview_text}" shared.state.job_no = hypernetwork.step @@ -299,7 +291,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,

Loss: {losses.mean():.7f}
Step: {hypernetwork.step}
-Last prompt: {html.escape(text)}
+Last prompt: {html.escape(entry.cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

diff --git a/modules/shared.py b/modules/shared.py index 42e99741..e64e69fc 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -231,6 +231,9 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"), + "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), + "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), + "training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), })) options_templates.update(options_section(('sd', "Stable Diffusion"), { diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index f61f40d3..67e90afe 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -11,11 +11,21 @@ import tqdm from modules import devices, shared import re -re_tag = re.compile(r"[a-zA-Z][_\w\d()]+") +re_numbers_at_start = re.compile(r"^[-\d]+\s*") + + +class DatasetEntry: + def __init__(self, filename=None, latent=None, filename_text=None): + self.filename = filename + self.latent = latent + self.filename_text = filename_text + self.cond = None + self.cond_text = None class PersonalizedBase(Dataset): def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False): + re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None self.placeholder_token = placeholder_token @@ -42,9 +52,18 @@ class PersonalizedBase(Dataset): except Exception: continue + text_filename = os.path.splitext(path)[0] + ".txt" filename = os.path.basename(path) - filename_tokens = os.path.splitext(filename)[0] - filename_tokens = re_tag.findall(filename_tokens) + + if os.path.exists(text_filename): + with open(text_filename, "r", encoding="utf8") as file: + filename_text = file.read() + else: + filename_text = os.path.splitext(filename)[0] + filename_text = re.sub(re_numbers_at_start, '', filename_text) + if re_word: + tokens = re_word.findall(filename_text) + filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens) npimage = np.array(image).astype(np.uint8) npimage = (npimage / 127.5 - 1.0).astype(np.float32) @@ -55,13 +74,13 @@ class PersonalizedBase(Dataset): init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() init_latent = init_latent.to(devices.cpu) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent) + if include_cond: - text = self.create_text(filename_tokens) - cond = cond_model([text]).to(devices.cpu) - else: - cond = None + entry.cond_text = self.create_text(filename_text) + entry.cond = cond_model([entry.cond_text]).to(devices.cpu) - self.dataset.append((init_latent, filename_tokens, cond)) + self.dataset.append(entry) self.length = len(self.dataset) * repeats @@ -72,10 +91,10 @@ class PersonalizedBase(Dataset): def shuffle(self): self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] - def create_text(self, filename_tokens): + def create_text(self, filename_text): text = random.choice(self.lines) text = text.replace("[name]", self.placeholder_token) - text = text.replace("[filewords]", ' '.join(filename_tokens)) + text = text.replace("[filewords]", filename_text) return text def __len__(self): @@ -86,7 +105,9 @@ class PersonalizedBase(Dataset): self.shuffle() index = self.indexes[i % len(self.indexes)] - x, filename_tokens, cond = self.dataset[index] + entry = self.dataset[index] + + if entry.cond is None: + entry.cond_text = self.create_text(entry.filename_text) - text = self.create_text(filename_tokens) - return x, text, cond + return entry diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py index db720271..2062726a 100644 --- a/modules/textual_inversion/learn_schedule.py +++ b/modules/textual_inversion/learn_schedule.py @@ -1,6 +1,12 @@ +import tqdm -class LearnSchedule: + +class LearnScheduleIterator: def __init__(self, learn_rate, max_steps, cur_step=0): + """ + specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000 + """ + pairs = learn_rate.split(',') self.rates = [] self.it = 0 @@ -32,3 +38,32 @@ class LearnSchedule: return self.rates[self.it - 1] else: raise StopIteration + + +class LearnRateScheduler: + def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True): + self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step) + (self.learn_rate, self.end_step) = next(self.schedules) + self.verbose = verbose + + if self.verbose: + print(f'Training at rate of {self.learn_rate} until step {self.end_step}') + + self.finished = False + + def apply(self, optimizer, step_number): + if step_number <= self.end_step: + return + + try: + (self.learn_rate, self.end_step) = next(self.schedules) + except Exception: + self.finished = True + return + + if self.verbose: + tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}') + + for pg in optimizer.param_groups: + pg['lr'] = self.learn_rate + diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index c5153e4a..fa0e33a2 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -11,7 +11,7 @@ from PIL import Image, PngImagePlugin from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset -from modules.textual_inversion.learn_schedule import LearnSchedule +from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, @@ -172,8 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn - -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -205,7 +204,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) hijack = sd_hijack.model_hijack @@ -221,32 +220,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if ititial_step > steps: return embedding, filename - schedules = iter(LearnSchedule(learn_rate, steps, ititial_step)) - (learn_rate, end_step) = next(schedules) - print(f'Training at rate of {learn_rate} until step {end_step}') - - optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate) + scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) - for i, (x, text, _) in pbar: + for i, entry in pbar: embedding.step = i + ititial_step - if embedding.step > end_step: - try: - (learn_rate, end_step) = next(schedules) - except: - break - tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}') - for pg in optimizer.param_groups: - pg['lr'] = learn_rate + scheduler.apply(optimizer, embedding.step) + if scheduler.finished: + break if shared.state.interrupted: break with torch.autocast("cuda"): - c = cond_model([text]) + c = cond_model([entry.cond_text]) - x = x.to(devices.device) + x = entry.latent.to(devices.device) loss = shared.sd_model(x.unsqueeze(0), c)[0] del x @@ -268,7 +259,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') - preview_text = text if preview_image_prompt == "" else preview_image_prompt + preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, @@ -314,7 +305,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini

Loss: {losses.mean():.7f}
Step: {embedding.step}
-Last prompt: {html.escape(text)}
+Last prompt: {html.escape(entry.cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

diff --git a/modules/ui.py b/modules/ui.py index 2b332267..c42535c8 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1098,7 +1098,6 @@ def create_ui(wrap_gradio_gpu_call): training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) steps = gr.Number(label='Max steps', value=100000, precision=0) - num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) @@ -1176,7 +1175,6 @@ def create_ui(wrap_gradio_gpu_call): training_width, training_height, steps, - num_repeats, create_image_every, save_embedding_every, template_file, -- cgit v1.2.3 From e72adc999b3531370eafb9d316924ac497feb445 Mon Sep 17 00:00:00 2001 From: Trung Ngo Date: Sat, 8 Oct 2022 22:57:19 -0500 Subject: Restore last generation params --- .gitignore | 1 + javascript/hints.js | 2 +- modules/generation_parameters_copypaste.py | 8 ++++++++ modules/processing.py | 4 ++++ 4 files changed, 14 insertions(+), 1 deletion(-) (limited to 'javascript/hints.js') diff --git a/.gitignore b/.gitignore index 7afc9395..69785b3e 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,7 @@ __pycache__ /webui.settings.bat /embeddings /styles.csv +/params.txt /styles.csv.bak /webui-user.bat /webui-user.sh diff --git a/javascript/hints.js b/javascript/hints.js index d51ee14c..32f10fde 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -14,7 +14,7 @@ titles = { "\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time", "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed", "\u{1f3a8}": "Add a random artist to the prompt.", - "\u2199\ufe0f": "Read generation parameters from prompt into user interface.", + "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.", "\u{1f4c2}": "Open images output directory", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index ac1ba7f4..3e75aecc 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -1,5 +1,7 @@ +import os import re import gradio as gr +from modules.shared import script_path re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)" re_param = re.compile(re_param_code) @@ -61,6 +63,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model def connect_paste(button, paste_fields, input_comp, js=None): def paste_func(prompt): + if not prompt: + filename = os.path.join(script_path, "params.txt") + if os.path.exists(filename): + with open(filename, "r", encoding="utf8") as file: + prompt = file.read() + params = parse_generation_parameters(prompt) res = [] diff --git a/modules/processing.py b/modules/processing.py index 698b3069..d5172f00 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -324,6 +324,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: else: assert p.prompt is not None + with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: + processed = Processed(p, [], p.seed, "") + file.write(processed.infotext(p, 0)) + devices.torch_gc() seed = get_fixed_seed(p.seed) -- cgit v1.2.3 From bb7baf6b9cb6b4b9fa09b6f07ef997db32fe6e58 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 13 Oct 2022 16:07:18 +0300 Subject: add option to change what's shown in quicksettings bar --- javascript/hints.js | 2 ++ modules/shared.py | 4 ++-- modules/ui.py | 16 +++++++++------- style.css | 1 + 4 files changed, 14 insertions(+), 9 deletions(-) (limited to 'javascript/hints.js') diff --git a/javascript/hints.js b/javascript/hints.js index 32f10fde..06bbd9e2 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -84,6 +84,8 @@ titles = { "Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.", "Filename join string": "This string will be used to hoin split words into a single line if the option above is enabled.", + + "Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual stetting tab. See modules/shared.py for setting names. Requires restart to apply." } diff --git a/modules/shared.py b/modules/shared.py index 5f6101a4..4d3ed625 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -152,7 +152,6 @@ class OptionInfo: self.component_args = component_args self.onchange = onchange self.section = None - self.show_on_main_page = show_on_main_page def options_section(section_identifier, options_dict): @@ -237,7 +236,7 @@ options_templates.update(options_section(('training', "Training"), { })) options_templates.update(options_section(('sd', "Stable Diffusion"), { - "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True), + "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), @@ -250,6 +249,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "filter_nsfw": OptionInfo(False, "Filter NSFW content"), 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), + 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { diff --git a/modules/ui.py b/modules/ui.py index e07ee0e1..a0529860 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1305,6 +1305,9 @@ Requested path was: {f} settings_cols = 3 items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols) + quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] + quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings') + quicksettings_list = [] cols_displayed = 0 @@ -1329,7 +1332,7 @@ Requested path was: {f} gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='

{}

'.format(item.section[1])) - if item.show_on_main_page: + if k in quicksettings_names: quicksettings_list.append((i, k, item)) components.append(dummy_component) else: @@ -1338,7 +1341,11 @@ Requested path was: {f} components.append(component) items_displayed += 1 - request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + with gr.Row(): + request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') + restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') + request_notifications.click( fn=lambda: None, inputs=[], @@ -1346,10 +1353,6 @@ Requested path was: {f} _js='function(){}' ) - with gr.Row(): - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') - restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') - def reload_scripts(): modules.scripts.reload_script_body_only() @@ -1364,7 +1367,6 @@ Requested path was: {f} shared.state.interrupt() settings_interface.gradio_ref.do_restart = True - restart_gradio.click( fn=request_restart, inputs=[], diff --git a/style.css b/style.css index e6fa10b4..55c41971 100644 --- a/style.css +++ b/style.css @@ -488,6 +488,7 @@ input[type="range"]{ #quicksettings > div > div{ max-width: 32em; padding: 0; + margin-right: 0.75em; } canvas[key="mask"] { -- cgit v1.2.3 From dccc181b55100b09182c1679c8dd75011aad7335 Mon Sep 17 00:00:00 2001 From: Taithrah Date: Thu, 13 Oct 2022 10:43:57 -0400 Subject: Update hints.js typo --- javascript/hints.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'javascript/hints.js') diff --git a/javascript/hints.js b/javascript/hints.js index 06bbd9e2..f65e7b88 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -85,7 +85,7 @@ titles = { "Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.", "Filename join string": "This string will be used to hoin split words into a single line if the option above is enabled.", - "Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual stetting tab. See modules/shared.py for setting names. Requires restart to apply." + "Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply." } -- cgit v1.2.3 From 494afccbc1d7b0aca4ffeb3d8354b09c414d95f4 Mon Sep 17 00:00:00 2001 From: crackfoo Date: Thu, 13 Oct 2022 20:26:54 -0700 Subject: Update hints.js typo --- javascript/hints.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'javascript/hints.js') diff --git a/javascript/hints.js b/javascript/hints.js index f65e7b88..94438c5c 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -83,7 +83,7 @@ titles = { "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.", "Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.", - "Filename join string": "This string will be used to hoin split words into a single line if the option above is enabled.", + "Filename join string": "This string will be used to join split words into a single line if the option above is enabled.", "Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply." } -- cgit v1.2.3 From fdecb636855748e03efc40c846a0043800aadfcc Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 14 Oct 2022 09:05:06 +0300 Subject: add an ability to merge three checkpoints --- javascript/hints.js | 5 ++++- modules/extras.py | 29 +++++++++++++++++++++-------- modules/ui.py | 11 +++++++---- 3 files changed, 32 insertions(+), 13 deletions(-) (limited to 'javascript/hints.js') diff --git a/javascript/hints.js b/javascript/hints.js index 94438c5c..af010a59 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -85,7 +85,10 @@ titles = { "Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.", "Filename join string": "This string will be used to join split words into a single line if the option above is enabled.", - "Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply." + "Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply.", + + "Weighted Sum": "Result = A * (1 - M) + B * M", + "Add difference": "Result = A + (B - C) * (1 - M)", } diff --git a/modules/extras.py b/modules/extras.py index b24d7de3..532d869f 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -159,48 +159,61 @@ def run_pnginfo(image): return '', geninfo, info -def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name): +def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, interp_amount, save_as_half, custom_name): # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation) - def weighted_sum(theta0, theta1, alpha): + def weighted_sum(theta0, theta1, theta2, alpha): return ((1 - alpha) * theta0) + (alpha * theta1) # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep) - def sigmoid(theta0, theta1, alpha): + def sigmoid(theta0, theta1, theta2, alpha): alpha = alpha * alpha * (3 - (2 * alpha)) return theta0 + ((theta1 - theta0) * alpha) # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep) - def inv_sigmoid(theta0, theta1, alpha): + def inv_sigmoid(theta0, theta1, theta2, alpha): import math alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0) return theta0 + ((theta1 - theta0) * alpha) + def add_difference(theta0, theta1, theta2, alpha): + return theta0 + (theta1 - theta2) * (1.0 - alpha) + primary_model_info = sd_models.checkpoints_list[primary_model_name] secondary_model_info = sd_models.checkpoints_list[secondary_model_name] + teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None) print(f"Loading {primary_model_info.filename}...") primary_model = torch.load(primary_model_info.filename, map_location='cpu') + theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model) print(f"Loading {secondary_model_info.filename}...") secondary_model = torch.load(secondary_model_info.filename, map_location='cpu') - - theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model) theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model) + if teritary_model_info is not None: + print(f"Loading {teritary_model_info.filename}...") + teritary_model = torch.load(teritary_model_info.filename, map_location='cpu') + theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model) + else: + theta_2 = None + theta_funcs = { "Weighted Sum": weighted_sum, "Sigmoid": sigmoid, "Inverse Sigmoid": inv_sigmoid, + "Add difference": add_difference, } theta_func = theta_funcs[interp_method] print(f"Merging...") + for key in tqdm.tqdm(theta_0.keys()): if 'model' in key and key in theta_1: - theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint + theta_0[key] = theta_func(theta_0[key], theta_1[key], theta_2[key] if theta_2 else None, (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint if save_as_half: theta_0[key] = theta_0[key].half() + # I believe this part should be discarded, but I'll leave it for now until I am sure for key in theta_1.keys(): if 'model' in key and key not in theta_0: theta_0[key] = theta_1[key] @@ -219,4 +232,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int sd_models.list_models() print(f"Checkpoint saved.") - return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)] + return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] diff --git a/modules/ui.py b/modules/ui.py index 7446439d..220fb80b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1024,11 +1024,12 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") with gr.Row(): - primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name") - secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name") + primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") + secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") + tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") custom_name = gr.Textbox(label="Custom Name (Optional)") - interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3) - interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method") + interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation amount (1 - M)', value=0.3) + interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid", "Add difference"], value="Weighted Sum", label="Interpolation Method") save_as_half = gr.Checkbox(value=False, label="Save as float16") modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') @@ -1473,6 +1474,7 @@ Requested path was: {f} inputs=[ primary_model_name, secondary_model_name, + tertiary_model_name, interp_method, interp_amount, save_as_half, @@ -1482,6 +1484,7 @@ Requested path was: {f} submit_result, primary_model_name, secondary_model_name, + tertiary_model_name, component_dict['sd_model_checkpoint'], ] ) -- cgit v1.2.3 From c250cb289c97fe303cef69064bf45899406f6a40 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 14 Oct 2022 22:01:49 +0300 Subject: change checkpoint merger to work in a more obvious way remove sigmoid and inverse sigmoid because they just did the same thing as weighed sum only with changed multiplier --- javascript/hints.js | 4 ++-- modules/extras.py | 24 +++++------------------- modules/ui.py | 4 ++-- 3 files changed, 9 insertions(+), 23 deletions(-) (limited to 'javascript/hints.js') diff --git a/javascript/hints.js b/javascript/hints.js index af010a59..8fec907d 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -87,8 +87,8 @@ titles = { "Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply.", - "Weighted Sum": "Result = A * (1 - M) + B * M", - "Add difference": "Result = A + (B - C) * (1 - M)", + "Weighted sum": "Result = A * (1 - M) + B * M", + "Add difference": "Result = A + (B - C) * M", } diff --git a/modules/extras.py b/modules/extras.py index 2e7b3751..f2f5a7b0 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -159,24 +159,12 @@ def run_pnginfo(image): return '', geninfo, info -def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, interp_amount, save_as_half, custom_name): - # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation) +def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name): def weighted_sum(theta0, theta1, theta2, alpha): return ((1 - alpha) * theta0) + (alpha * theta1) - # Smoothstep (https://en.wikipedia.org/wiki/Smoothstep) - def sigmoid(theta0, theta1, theta2, alpha): - alpha = alpha * alpha * (3 - (2 * alpha)) - return theta0 + ((theta1 - theta0) * alpha) - - # Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep) - def inv_sigmoid(theta0, theta1, theta2, alpha): - import math - alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0) - return theta0 + ((theta1 - theta0) * alpha) - def add_difference(theta0, theta1, theta2, alpha): - return theta0 + (theta1 - theta2) * (1.0 - alpha) + return theta0 + (theta1 - theta2) * alpha primary_model_info = sd_models.checkpoints_list[primary_model_name] secondary_model_info = sd_models.checkpoints_list[secondary_model_name] @@ -198,9 +186,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam theta_2 = None theta_funcs = { - "Weighted Sum": weighted_sum, - "Sigmoid": sigmoid, - "Inverse Sigmoid": inv_sigmoid, + "Weighted sum": weighted_sum, "Add difference": add_difference, } theta_func = theta_funcs[interp_method] @@ -213,7 +199,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam if t2 is None: t2 = torch.zeros_like(theta_0[key]) - theta_0[key] = theta_func(theta_0[key], theta_1[key], t2, (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint + theta_0[key] = theta_func(theta_0[key], theta_1[key], t2, multiplier) if save_as_half: theta_0[key] = theta_0[key].half() @@ -227,7 +213,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path - filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt' + filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt' filename = filename if custom_name == '' else (custom_name + '.ckpt') output_modelname = os.path.join(ckpt_dir, filename) diff --git a/modules/ui.py b/modules/ui.py index 4a04c2cc..a08ffc9b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1101,8 +1101,8 @@ def create_ui(wrap_gradio_gpu_call): secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") custom_name = gr.Textbox(label="Custom Name (Optional)") - interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation amount (1 - M)', value=0.3) - interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid", "Add difference"], value="Weighted Sum", label="Interpolation Method") + interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3) + interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method") save_as_half = gr.Checkbox(value=False, label="Save as float16") modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') -- cgit v1.2.3