From 70931652a4289e28d83869b6d10cf11e80a70345 Mon Sep 17 00:00:00 2001 From: RnDMonkey Date: Fri, 30 Sep 2022 18:02:46 -0700 Subject: [xy_grid] made -1 seed fixing apply to Var. seed too --- scripts/xy_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 146663b0..9c078888 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -218,7 +218,7 @@ class Script(scripts.Script): ys = process_axis(y_opt, y_values) def fix_axis_seeds(axis_opt, axis_list): - if axis_opt.label == 'Seed': + if axis_opt.label == 'Seed' or 'Var. seed': return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] else: return axis_list -- cgit v1.2.3 From cf141157e7b49b0b3a6e57dc7aa0d1345158b4c8 Mon Sep 17 00:00:00 2001 From: RnDMonkey Date: Fri, 30 Sep 2022 22:02:29 -0700 Subject: Added X/Y plot parameters to extra_generation_params --- scripts/xy_grid.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 9c078888..d9f8d55b 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -244,6 +244,14 @@ class Script(scripts.Script): return process_images(pc) + if not x_opt.label == 'Nothing': + p.extra_generation_params["X/Y Plot X Type"] = x_opt.label + p.extra_generation_params["X Values"] = '{' + ", ".join([f'{x}' for x in xs]) + '}' + + if not y_opt.label == 'Nothing': + p.extra_generation_params["X/Y Plot Y Type"] = y_opt.label + p.extra_generation_params["Y Values"] = '{' + ", ".join([f'{y}' for y in ys]) + '}' + processed = draw_xy_grid( p, xs=xs, -- cgit v1.2.3 From eba0c29dbc3bad8c4e32f1fa3a03dc6f9caf1f5a Mon Sep 17 00:00:00 2001 From: RnDMonkey Date: Sat, 1 Oct 2022 13:56:29 -0700 Subject: Updated xy_grid infotext formatting, parser regex --- modules/generation_parameters_copypaste.py | 2 +- scripts/xy_grid.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index ac1ba7f4..39d67d94 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -1,7 +1,7 @@ import re import gradio as gr -re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)" +re_param_code = r"\s*([\w ]+):\s*((?:{[^}]+})|(?:[^,]+))(?:,|$)" re_param = re.compile(re_param_code) re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") re_imagesize = re.compile(r"^(\d+)x(\d+)$") diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index d9f8d55b..f87c6c1f 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -245,12 +245,16 @@ class Script(scripts.Script): return process_images(pc) if not x_opt.label == 'Nothing': - p.extra_generation_params["X/Y Plot X Type"] = x_opt.label - p.extra_generation_params["X Values"] = '{' + ", ".join([f'{x}' for x in xs]) + '}' + p.extra_generation_params["XY Plot X Type"] = x_opt.label + p.extra_generation_params["X Values"] = '{' + x_values + '}' + if x_opt.label in ["Seed","Var. seed"] and not no_fixed_seeds: + p.extra_generation_params["Fixed X Values"] = '{' + ", ".join([str(x) for x in xs])+ '}' if not y_opt.label == 'Nothing': - p.extra_generation_params["X/Y Plot Y Type"] = y_opt.label - p.extra_generation_params["Y Values"] = '{' + ", ".join([f'{y}' for y in ys]) + '}' + p.extra_generation_params["XY Plot Y Type"] = y_opt.label + p.extra_generation_params["Y Values"] = '{' + y_values + '}' + if y_opt.label in ["Seed","Var. 
seed"] and not no_fixed_seeds: + p.extra_generation_params["Fixed Y Values"] = '{' + ", ".join([str(y) for y in ys])+ '}' processed = draw_xy_grid( p, -- cgit v1.2.3 From b99a4f769f11ed74df0344a23069d3858613fbef Mon Sep 17 00:00:00 2001 From: RnDMonkey Date: Sat, 1 Oct 2022 14:26:12 -0700 Subject: fixed expression error in condition --- scripts/xy_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index f87c6c1f..f1f54d9c 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -218,7 +218,7 @@ class Script(scripts.Script): ys = process_axis(y_opt, y_values) def fix_axis_seeds(axis_opt, axis_list): - if axis_opt.label == 'Seed' or 'Var. seed': + if axis_opt.label in ["Seed","Var. seed"]: return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] else: return axis_list -- cgit v1.2.3 From f6a97868e57e44fba6c4283769fedd30ee11cacf Mon Sep 17 00:00:00 2001 From: RnDMonkey Date: Sat, 1 Oct 2022 14:36:09 -0700 Subject: fix to allow empty {} values --- modules/generation_parameters_copypaste.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 39d67d94..27d58dfd 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -1,7 +1,7 @@ import re import gradio as gr -re_param_code = r"\s*([\w ]+):\s*((?:{[^}]+})|(?:[^,]+))(?:,|$)" +re_param_code = r"\s*([\w ]+):\s*((?:{[^}]*})|(?:[^,]+))(?:,|$)" re_param = re.compile(re_param_code) re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") re_imagesize = re.compile(r"^(\d+)x(\d+)$") -- cgit v1.2.3 From fe6e2362e8fa5d739de6997ab155a26686d20a49 Mon Sep 17 00:00:00 2001 From: RnDMonkey Date: Sun, 2 Oct 2022 22:04:28 -0700 Subject: Update xy_grid.py Changed XY Plot infotext value keys to not be so generic. --- scripts/xy_grid.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index f1f54d9c..ae011a17 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -246,15 +246,15 @@ class Script(scripts.Script): if not x_opt.label == 'Nothing': p.extra_generation_params["XY Plot X Type"] = x_opt.label - p.extra_generation_params["X Values"] = '{' + x_values + '}' + p.extra_generation_params["XY Plot X Values"] = '{' + x_values + '}' if x_opt.label in ["Seed","Var. seed"] and not no_fixed_seeds: - p.extra_generation_params["Fixed X Values"] = '{' + ", ".join([str(x) for x in xs])+ '}' + p.extra_generation_params["XY Plot Fixed X Values"] = '{' + ", ".join([str(x) for x in xs])+ '}' if not y_opt.label == 'Nothing': p.extra_generation_params["XY Plot Y Type"] = y_opt.label - p.extra_generation_params["Y Values"] = '{' + y_values + '}' + p.extra_generation_params["XY Plot Y Values"] = '{' + y_values + '}' if y_opt.label in ["Seed","Var. 
seed"] and not no_fixed_seeds: - p.extra_generation_params["Fixed Y Values"] = '{' + ", ".join([str(y) for y in ys])+ '}' + p.extra_generation_params["XY Plot Fixed Y Values"] = '{' + ", ".join([str(y) for y in ys])+ '}' processed = draw_xy_grid( p, -- cgit v1.2.3 From 14c1c2b9351f16d43ba4e6b6c9062edad44a6bec Mon Sep 17 00:00:00 2001 From: Alexandre Simard Date: Wed, 19 Oct 2022 13:53:52 -0400 Subject: Show PB texts at same time and earlier For big tasks (1000+ steps), waiting 1 minute to see ETA is long and this changes it so the number of steps done plays a role in showing the text as well. --- modules/ui.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index a2dbd41e..0abd177a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -261,14 +261,14 @@ def wrap_gradio_call(func, extra_outputs=None): return f -def calc_time_left(progress, threshold, label, force_display): +def calc_time_left(progress, threshold, label, force_display, showTime): if progress == 0: return "" else: time_since_start = time.time() - shared.state.time_start eta = (time_since_start/progress) eta_relative = eta-time_since_start - if (eta_relative > threshold and progress > 0.02) or force_display: + if (eta_relative > threshold and showTime) or force_display: if eta_relative > 3600: return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) elif eta_relative > 60: @@ -290,7 +290,10 @@ def check_progress_call(id_part): if shared.state.sampling_steps > 0: progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps - time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display ) + # Show progress percentage and time left at the same moment, and base it also on steps done + showPBText = progress >= 0.01 or shared.state.sampling_step >= 10 + + time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display, showPBText ) if time_left != "": shared.state.time_left_force_display = True @@ -298,7 +301,7 @@ def check_progress_call(id_part): progressbar = "" if opts.show_progressbar: - progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if progress > 0.01 else ""}
""" + progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if showPBText else ""}
""" image = gr_show(False) preview_visibility = gr_show(False) -- cgit v1.2.3 From 4fbdbddc18b21f712acae58bf41740d27023285f Mon Sep 17 00:00:00 2001 From: Alexandre Simard Date: Wed, 19 Oct 2022 15:21:36 -0400 Subject: Remove pad spaces from progress bar text --- javascript/progressbar.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 7a05726e..24ab4795 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -10,7 +10,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){ if(progressbar.innerText){ - let newtitle = 'Stable Diffusion - ' + progressbar.innerText + let newtitle = 'Stable Diffusion - ' + progressbar.innerText.slice(2) if(document.title != newtitle){ document.title = newtitle; } -- cgit v1.2.3 From 29e74d6e71826da9a3fe3c5790fed1329fc4d1e8 Mon Sep 17 00:00:00 2001 From: Melan Date: Thu, 20 Oct 2022 16:26:16 +0200 Subject: Add support for Tensorboard for training embeddings --- modules/shared.py | 4 ++++ modules/textual_inversion/textual_inversion.py | 31 +++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index faede821..2c6341f7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -254,6 +254,10 @@ options_templates.update(options_section(('training', "Training"), { "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"), + "training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."), + "training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."), + "training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."), + })) options_templates.update(options_section(('sd', "Stable Diffusion"), { diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 3be69562..c57d3ace 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -7,9 +7,11 @@ import tqdm import html import datetime import csv +import numpy as np +import torchvision.transforms from PIL import Image, PngImagePlugin - +from torch.utils.tensorboard import SummaryWriter from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler @@ -199,6 +201,19 @@ def write_loss(log_directory, filename, step, epoch_len, values): **values, }) +def tensorboard_add_scaler(tensorboard_writer, tag, value, step): + if shared.opts.training_enable_tensorboard: + tensorboard_writer.add_scalar(tag=tag, + scalar_value=value, global_step=step) + +def tensorboard_add_image(tensorboard_writer, tag, pil_image, step): + if shared.opts.training_enable_tensorboard: + # Convert a pil image to a torch tensor + img_tensor = torch.as_tensor(np.array(pil_image, copy=True)) + img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], len(pil_image.getbands())) + img_tensor = img_tensor.permute((2, 0, 
1)) + + tensorboard_writer.add_image(tag, img_tensor, global_step=step) def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): assert embedding_name, 'embedding not selected' @@ -252,6 +267,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) + if shared.opts.training_enable_tensorboard: + os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True) + tensorboard_writer = SummaryWriter( + log_dir=os.path.join(log_directory, "tensorboard"), + flush_secs=shared.opts.training_tensorboard_flush_every) + pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) for i, entries in pbar: embedding.step = i + ititial_step @@ -270,6 +291,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc del x losses[embedding.step % losses.shape[0]] = loss.item() + optimizer.zero_grad() loss.backward() @@ -285,6 +307,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc embedding.save(last_saved_file) embedding_yet_to_be_embedded = True + if shared.opts.training_enable_tensorboard: + tensorboard_add_scaler(tensorboard_writer, "Loss/train", losses.mean(), embedding.step) + tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", losses.mean(), epoch_step) + tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", scheduler.learn_rate, embedding.step) + tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", scheduler.learn_rate, epoch_step) + write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { "loss": f"{losses.mean():.7f}", "learn_rate": scheduler.learn_rate @@ -349,6 +377,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc embedding_yet_to_be_embedded = False image.save(last_saved_image) + tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step) last_saved_image += f", prompt: {preview_text}" -- cgit v1.2.3 From a6d593a6b51dc6a8443f2aa5c24caa391a04cd56 Mon Sep 17 00:00:00 2001 From: Melan Date: Thu, 20 Oct 2022 19:43:21 +0200 Subject: Fixed a typo in a variable --- modules/textual_inversion/textual_inversion.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index c57d3ace..ec8176bf 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -260,11 +260,11 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc last_saved_image = "" embedding_yet_to_be_embedded = False - ititial_step = embedding.step or 0 - if ititial_step > steps: + initial_step = embedding.step or 0 + if initial_step > steps: return embedding, filename - scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) if shared.opts.training_enable_tensorboard: @@ -273,9 +273,9 @@ 
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc log_dir=os.path.join(log_directory, "tensorboard"), flush_secs=shared.opts.training_tensorboard_flush_every) - pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + pbar = tqdm.tqdm(enumerate(ds), total=steps-initial_step) for i, entries in pbar: - embedding.step = i + ititial_step + embedding.step = i + initial_step scheduler.apply(optimizer, embedding.step) if scheduler.finished: -- cgit v1.2.3 From 8f5912984794c4c69e429c4636e984854d911b6a Mon Sep 17 00:00:00 2001 From: Melan Date: Thu, 20 Oct 2022 22:37:16 +0200 Subject: Some changes to the tensorboard code and hypernetwork support --- modules/hypernetworks/hypernetwork.py | 18 ++++++++++- modules/textual_inversion/textual_inversion.py | 45 +++++++++++++++----------- 2 files changed, 44 insertions(+), 19 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 74300122..5e919775 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -4,6 +4,7 @@ import html import os import sys import traceback +import tensorboard import tqdm import csv @@ -18,7 +19,6 @@ import modules.textual_inversion.dataset from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler - class HypernetworkModule(torch.nn.Module): multiplier = 1.0 @@ -291,6 +291,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) + if shared.opts.training_enable_tensorboard: + tensorboard_writer = textual_inversion.tensorboard_setup(log_directory) + pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, entries in pbar: hypernetwork.step = i + ititial_step @@ -315,6 +318,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log optimizer.zero_grad() loss.backward() optimizer.step() + mean_loss = losses.mean() if torch.isnan(mean_loss): raise RuntimeError("Loss diverged.") @@ -323,6 +327,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') hypernetwork.save(last_saved_file) + + if shared.opts.training_enable_tensorboard: + epoch_num = hypernetwork.step // len(ds) + epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1 + + textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, + global_step=hypernetwork.step, step=epoch_step, + learn_rate=scheduler.learn_rate, epoch_num=epoch_num) textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { "loss": f"{mean_loss:.7f}", @@ -360,6 +372,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log processed = processing.process_images(p) image = processed.images[0] if len(processed.images)>0 else None + if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images: + textual_inversion.tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", + image, hypernetwork.step) + if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) diff --git 
a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index ec8176bf..b1dc2596 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -201,19 +201,30 @@ def write_loss(log_directory, filename, step, epoch_len, values): **values, }) +def tensorboard_setup(log_directory): + os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True) + return SummaryWriter( + log_dir=os.path.join(log_directory, "tensorboard"), + flush_secs=shared.opts.training_tensorboard_flush_every) + +def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num): + tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step) + tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step) + tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step) + tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step) + def tensorboard_add_scaler(tensorboard_writer, tag, value, step): - if shared.opts.training_enable_tensorboard: - tensorboard_writer.add_scalar(tag=tag, - scalar_value=value, global_step=step) + tensorboard_writer.add_scalar(tag=tag, + scalar_value=value, global_step=step) def tensorboard_add_image(tensorboard_writer, tag, pil_image, step): - if shared.opts.training_enable_tensorboard: - # Convert a pil image to a torch tensor - img_tensor = torch.as_tensor(np.array(pil_image, copy=True)) - img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], len(pil_image.getbands())) - img_tensor = img_tensor.permute((2, 0, 1)) + # Convert a pil image to a torch tensor + img_tensor = torch.as_tensor(np.array(pil_image, copy=True)) + img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], + len(pil_image.getbands())) + img_tensor = img_tensor.permute((2, 0, 1)) - tensorboard_writer.add_image(tag, img_tensor, global_step=step) + tensorboard_writer.add_image(tag, img_tensor, global_step=step) def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): assert embedding_name, 'embedding not selected' @@ -268,10 +279,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) if shared.opts.training_enable_tensorboard: - os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True) - tensorboard_writer = SummaryWriter( - log_dir=os.path.join(log_directory, "tensorboard"), - flush_secs=shared.opts.training_tensorboard_flush_every) + tensorboard_writer = tensorboard_setup(log_directory) pbar = tqdm.tqdm(enumerate(ds), total=steps-initial_step) for i, entries in pbar: @@ -308,10 +316,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc embedding_yet_to_be_embedded = True if shared.opts.training_enable_tensorboard: - tensorboard_add_scaler(tensorboard_writer, "Loss/train", losses.mean(), embedding.step) - tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", losses.mean(), epoch_step) - tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", scheduler.learn_rate, embedding.step) - 
tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", scheduler.learn_rate, epoch_step) + tensorboard_add(tensorboard_writer, loss=losses.mean(), global_step=embedding.step, + step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num) write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { "loss": f"{losses.mean():.7f}", @@ -377,7 +383,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc embedding_yet_to_be_embedded = False image.save(last_saved_image) - tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step) + + if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images: + tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", + image, embedding.step) last_saved_image += f", prompt: {preview_text}" -- cgit v1.2.3 From 7543cf5e3b5eaced00582da257801227d1ff2a6e Mon Sep 17 00:00:00 2001 From: Melan Date: Thu, 20 Oct 2022 22:43:08 +0200 Subject: Fixed some typos in the code --- modules/hypernetworks/hypernetwork.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 5e919775..0cd94f49 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -284,19 +284,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log last_saved_file = "" last_saved_image = "" - ititial_step = hypernetwork.step or 0 - if ititial_step > steps: + initial_step = hypernetwork.step or 0 + if initial_step > steps: return hypernetwork, filename - scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) if shared.opts.training_enable_tensorboard: tensorboard_writer = textual_inversion.tensorboard_setup(log_directory) - pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) + pbar = tqdm.tqdm(enumerate(ds), total=steps - initial_step) for i, entries in pbar: - hypernetwork.step = i + ititial_step + hypernetwork.step = i + initial_step scheduler.apply(optimizer, hypernetwork.step) if scheduler.finished: -- cgit v1.2.3 From 18f86e41f6f289042c075bff1498e620ab997b8c Mon Sep 17 00:00:00 2001 From: Melan Date: Mon, 24 Oct 2022 17:21:18 +0200 Subject: Removed two unused imports --- modules/hypernetworks/hypernetwork.py | 1 - modules/textual_inversion/textual_inversion.py | 1 - 2 files changed, 2 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 0cd94f49..2263e95e 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -4,7 +4,6 @@ import html import os import sys import traceback -import tensorboard import tqdm import csv diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index b1dc2596..589314fe 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -9,7 +9,6 @@ import datetime import csv import numpy as np -import torchvision.transforms from PIL import Image, PngImagePlugin from torch.utils.tensorboard import SummaryWriter from modules import shared, devices, sd_hijack, processing, sd_models -- cgit v1.2.3 From c4b5ca5778340b21288d84dfb8fe1d5773c886a8 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: 
Thu, 27 Oct 2022 22:00:28 +0900 Subject: Truncate too long filename --- modules/images.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/modules/images.py b/modules/images.py index 7870b5b7..42363ed3 100644 --- a/modules/images.py +++ b/modules/images.py @@ -416,6 +416,14 @@ def get_next_sequence_number(path, basename): return result + 1 +def truncate_fullpath(full_path, encoding='utf-8'): + dir_name, full_name = os.path.split(full_path) + file_name, file_ext = os.path.splitext(full_name) + max_length = os.statvfs(dir_name).f_namemax + file_name_truncated = file_name.encode(encoding)[:max_length - len(file_ext)].decode(encoding, 'ignore') + return os.path.join(dir_name , file_name_truncated + file_ext) + + def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None): """Save an image. @@ -456,7 +464,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i if save_to_dirs: dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /') - path = os.path.join(path, dirname) + path = truncate_fullpath(os.path.join(path, dirname)) os.makedirs(path, exist_ok=True) @@ -480,13 +488,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i fullfn = None for i in range(500): fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}" - fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}") + fullfn = truncate_fullpath(os.path.join(path, f"{fn}{file_decoration}.{extension}")) if not os.path.exists(fullfn): break else: - fullfn = os.path.join(path, f"{file_decoration}.{extension}") + fullfn = truncate_fullpath(os.path.join(path, f"{file_decoration}.{extension}")) else: - fullfn = os.path.join(path, f"{forced_filename}.{extension}") + fullfn = truncate_fullpath(os.path.join(path, f"{forced_filename}.{extension}")) pnginfo = existing_info or {} if info is not None: -- cgit v1.2.3 From 2a25729623717cc499e873752d9f4ebebd1e1078 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Fri, 28 Oct 2022 09:44:56 +0700 Subject: Gradient clipping in train tab --- modules/hypernetworks/hypernetwork.py | 10 +++++++++- modules/ui.py | 7 +++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 8113b35b..c5d60654 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -327,7 +327,7 @@ def report_statistics(loss_info:dict): -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. 
Importing it at the top causes a circular import problem. from modules import images @@ -384,6 +384,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if ititial_step > steps: return hypernetwork, filename + clip_grad_mode_value = clip_grad_mode == "value" + clip_grad_mode_norm = clip_grad_mode == "norm" + scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc... optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) @@ -426,6 +429,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log steps_without_grad = 0 assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue' + if clip_grad_mode_value: + torch.nn.utils.clip_grad_value_(weights, clip_value=clip_grad_value) + elif clip_grad_mode_norm: + torch.nn.utils.clip_grad_norm_(weights, max_norm=clip_grad_value) + optimizer.step() if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): diff --git a/modules/ui.py b/modules/ui.py index 0a63e357..97de7da2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1313,6 +1313,9 @@ def create_ui(wrap_gradio_gpu_call): training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) steps = gr.Number(label='Max steps', value=100000, precision=0) + with gr.Row(): + clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) + clip_grad_value = gr.Number(value=1.0, show_label=False) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) @@ -1406,6 +1409,8 @@ def create_ui(wrap_gradio_gpu_call): training_width, training_height, steps, + clip_grad_mode, + clip_grad_value, create_image_every, save_embedding_every, template_file, @@ -1431,6 +1436,8 @@ def create_ui(wrap_gradio_gpu_call): training_width, training_height, steps, + clip_grad_mode, + clip_grad_value, create_image_every, save_embedding_every, template_file, -- cgit v1.2.3 From a133042c669f666763f5da0f4440abdc839db653 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Fri, 28 Oct 2022 10:01:46 +0700 Subject: Forgot to remove this from train_embedding --- modules/ui.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 97de7da2..ba5e92a7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1409,8 +1409,6 @@ def create_ui(wrap_gradio_gpu_call): training_width, training_height, steps, - clip_grad_mode, - clip_grad_value, create_image_every, save_embedding_every, template_file, -- cgit v1.2.3 From 1618df41bad092e068c61bf510b1e20856821ad5 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Fri, 28 Oct 2022 10:31:27 +0700 Subject: Gradient clipping for textual embedding --- modules/textual_inversion/textual_inversion.py | 11 ++++++++++- modules/ui.py | 2 ++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index ff002d3e..7bad73a6 100644 --- a/modules/textual_inversion/textual_inversion.py +++ 
b/modules/textual_inversion/textual_inversion.py @@ -206,7 +206,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): }) -def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -256,6 +256,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc if ititial_step > steps: return embedding, filename + clip_grad_mode_value = clip_grad_mode == "value" + clip_grad_mode_norm = clip_grad_mode == "norm" + scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) @@ -280,6 +283,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc optimizer.zero_grad() loss.backward() + + if clip_grad_mode_value: + torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_value) + elif clip_grad_mode_norm: + torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_value) + optimizer.step() diff --git a/modules/ui.py b/modules/ui.py index ba5e92a7..97de7da2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1409,6 +1409,8 @@ def create_ui(wrap_gradio_gpu_call): training_width, training_height, steps, + clip_grad_mode, + clip_grad_value, create_image_every, save_embedding_every, template_file, -- cgit v1.2.3 From 16451ca573220e49f2eaaab97580b6b91287c8c4 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Fri, 28 Oct 2022 17:16:23 +0700 Subject: Learning rate sched syntax support for grad clipping --- modules/hypernetworks/hypernetwork.py | 13 ++++++++++--- modules/textual_inversion/learn_schedule.py | 11 ++++++++--- modules/textual_inversion/textual_inversion.py | 12 +++++++++--- modules/ui.py | 7 +++---- 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index c5d60654..86532063 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -383,11 +383,15 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log ititial_step = hypernetwork.step or 0 if ititial_step > steps: return hypernetwork, filename - + clip_grad_mode_value = clip_grad_mode == "value" clip_grad_mode_norm = clip_grad_mode == "norm" + clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm + if clip_grad_enabled: + clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False) scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc... 
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) @@ -407,6 +411,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if shared.state.interrupted: break + if clip_grad_enabled: + clip_grad_sched.step(hypernetwork.step) + with torch.autocast("cuda"): c = stack_conds([entry.cond for entry in entries]).to(devices.device) # c = torch.vstack([entry.cond for entry in entries]).to(devices.device) @@ -430,9 +437,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue' if clip_grad_mode_value: - torch.nn.utils.clip_grad_value_(weights, clip_value=clip_grad_value) + torch.nn.utils.clip_grad_value_(weights, clip_value=clip_grad_sched.learn_rate) elif clip_grad_mode_norm: - torch.nn.utils.clip_grad_norm_(weights, max_norm=clip_grad_value) + torch.nn.utils.clip_grad_norm_(weights, max_norm=clip_grad_sched.learn_rate) optimizer.step() diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py index 2062726a..ffec3e1b 100644 --- a/modules/textual_inversion/learn_schedule.py +++ b/modules/textual_inversion/learn_schedule.py @@ -51,14 +51,19 @@ class LearnRateScheduler: self.finished = False - def apply(self, optimizer, step_number): + def step(self, step_number): if step_number <= self.end_step: - return + return False try: (self.learn_rate, self.end_step) = next(self.schedules) - except Exception: + except StopIteration: self.finished = True + return False + return True + + def apply(self, optimizer, step_number): + if not self.step(step_number): return if self.verbose: diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 7bad73a6..6b00c6a1 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -255,9 +255,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ititial_step = embedding.step or 0 if ititial_step > steps: return embedding, filename - + clip_grad_mode_value = clip_grad_mode == "value" clip_grad_mode_norm = clip_grad_mode == "norm" + clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm + if clip_grad_enabled: + clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False) scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) @@ -273,6 +276,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc if shared.state.interrupted: break + if clip_grad_enabled: + clip_grad_sched.step(embedding.step) + with torch.autocast("cuda"): c = cond_model([entry.cond_text for entry in entries]) x = torch.stack([entry.latent for entry in entries]).to(devices.device) @@ -285,9 +291,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc loss.backward() if clip_grad_mode_value: - torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_value) + torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_sched.learn_rate) elif clip_grad_mode_norm: - torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_value) + torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_sched.learn_rate) optimizer.step() diff --git a/modules/ui.py b/modules/ui.py index 97de7da2..47d16429 100644 --- 
a/modules/ui.py +++ b/modules/ui.py @@ -1305,7 +1305,9 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row(): embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005") hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") - + with gr.Row(): + clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) + clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="1.0", show_label=False) batch_size = gr.Number(label='Batch size', value=1, precision=0) dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") @@ -1313,9 +1315,6 @@ def create_ui(wrap_gradio_gpu_call): training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) steps = gr.Number(label='Max steps', value=100000, precision=0) - with gr.Row(): - clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) - clip_grad_value = gr.Number(value=1.0, show_label=False) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) -- cgit v1.2.3 From 840307f23738c38f7ac3ad636e53ccec66e71f8b Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Mon, 31 Oct 2022 13:49:24 +0700 Subject: Change default clip grad value to 0.1 It still defaults to disabled. 
Ref for value: https://github.com/danielalcalde/stable-diffusion-webui/commit/732b15820a9bde9f47e075a6209c3d47d47acb08 --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index 98f9565f..364953aa 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1256,7 +1256,7 @@ def create_ui(wrap_gradio_gpu_call): hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") with gr.Row(): clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) - clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="1.0", show_label=False) + clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) batch_size = gr.Number(label='Batch size', value=1, precision=0) dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") -- cgit v1.2.3 From 4123be632a98f70cda06e14c2f556f7ad38cd436 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Mon, 31 Oct 2022 13:53:22 +0700 Subject: Fix merge conflicts --- modules/hypernetworks/hypernetwork.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 65a584bb..207808ee 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -373,6 +373,12 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + clip_grad_mode_value = clip_grad_mode == "value" + clip_grad_mode_norm = clip_grad_mode == "norm" + clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm + if clip_grad_enabled: + clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False) + # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." 
with torch.autocast("cuda"): @@ -389,21 +395,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log previous_mean_loss = 0 print("Mean loss of {} elements".format(size)) - last_saved_file = "" - last_saved_image = "" - forced_filename = "" - ititial_step = hypernetwork.step or 0 if ititial_step > steps: return hypernetwork, filename - clip_grad_mode_value = clip_grad_mode == "value" - clip_grad_mode_norm = clip_grad_mode == "norm" - clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm - if clip_grad_enabled: - clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False) - - scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) weights = hypernetwork.weights() for weight in weights: -- cgit v1.2.3 From d5ea878b2aa117588d85287cbd8983aa52177df5 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Mon, 31 Oct 2022 13:54:40 +0700 Subject: Fix merge conflicts --- modules/hypernetworks/hypernetwork.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 207808ee..2df38c70 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -395,11 +395,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log previous_mean_loss = 0 print("Mean loss of {} elements".format(size)) - ititial_step = hypernetwork.step or 0 - if ititial_step > steps: - return hypernetwork, filename - - weights = hypernetwork.weights() for weight in weights: weight.requires_grad = True -- cgit v1.2.3 From cffc240a7327ae60671ff533469fc4ed4bf605de Mon Sep 17 00:00:00 2001 From: Nerogar Date: Sun, 23 Oct 2022 14:05:25 +0200 Subject: fixed textual inversion training with inpainting models --- modules/textual_inversion/textual_inversion.py | 27 +++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 0aeb0459..2630c7c9 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -224,6 +224,26 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" +def create_dummy_mask(x, width=None, height=None): + if shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}: + + # The "masked-image" in this case will just be all zeros since the entire image is masked. + image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) + image_conditioning = shared.sd_model.get_first_stage_encoding(shared.sd_model.encode_first_stage(image_conditioning)) + + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) + + else: + # Dummy zero conditioning if we're not using inpainting model. + # Still takes up a bit of memory, but no encoder call. + # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. 
+ image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) + + return image_conditioning + + def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 @@ -286,6 +306,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc forced_filename = "" embedding_yet_to_be_embedded = False + img_c = None pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) for i, entries in pbar: embedding.step = i + ititial_step @@ -299,8 +320,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc with torch.autocast("cuda"): c = cond_model([entry.cond_text for entry in entries]) + if img_c is None: + img_c = create_dummy_mask(c, training_width, training_height) + x = torch.stack([entry.latent for entry in entries]).to(devices.device) - loss = shared.sd_model(x, c)[0] + cond = {"c_concat": [img_c], "c_crossattn": [c]} + loss = shared.sd_model(x, cond)[0] del x losses[embedding.step % losses.shape[0]] = loss.item() -- cgit v1.2.3 From d624cb82a7c65a1ea04e4b6e23f0164a3ba25e25 Mon Sep 17 00:00:00 2001 From: Ikko Ashimine Date: Thu, 3 Nov 2022 01:05:00 +0900 Subject: Fix typo in ui.js interation -> interaction --- javascript/ui.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/ui.js b/javascript/ui.js index 7e116465..0308dce3 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -1,4 +1,4 @@ -// various functions for interation with ui.py not large enough to warrant putting them in separate files +// various functions for interaction with ui.py not large enough to warrant putting them in separate files function set_theme(theme){ gradioURL = window.location.href -- cgit v1.2.3 From bb832d7725187f8a8ab44faa6ee1b38cb5f600aa Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sat, 5 Nov 2022 11:48:38 +0700 Subject: Simplify grad clip --- modules/hypernetworks/hypernetwork.py | 16 +++++++--------- modules/textual_inversion/textual_inversion.py | 16 +++++++--------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index f4c2668f..02b624e1 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -385,10 +385,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - clip_grad_mode_value = clip_grad_mode == "value" - clip_grad_mode_norm = clip_grad_mode == "norm" - clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm - if clip_grad_enabled: + clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \ + torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \ + None + if clip_grad: clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False) # dataset loading may take a while, so input validations and early returns should be done before this @@ -433,7 +433,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if shared.state.interrupted: break - if 
clip_grad_enabled: + if clip_grad: clip_grad_sched.step(hypernetwork.step) with torch.autocast("cuda"): @@ -458,10 +458,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log steps_without_grad = 0 assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue' - if clip_grad_mode_value: - torch.nn.utils.clip_grad_value_(weights, clip_value=clip_grad_sched.learn_rate) - elif clip_grad_mode_norm: - torch.nn.utils.clip_grad_norm_(weights, max_norm=clip_grad_sched.learn_rate) + if clip_grad: + clip_grad(weights, clip_grad_sched.learn_rate) optimizer.step() diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index c567ec3f..687d97bb 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -269,10 +269,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) - clip_grad_mode_value = clip_grad_mode == "value" - clip_grad_mode_norm = clip_grad_mode == "norm" - clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm - if clip_grad_enabled: + clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \ + torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \ + None + if clip_grad: clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False) # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." @@ -302,7 +302,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc if shared.state.interrupted: break - if clip_grad_enabled: + if clip_grad: clip_grad_sched.step(embedding.step) with torch.autocast("cuda"): @@ -316,10 +316,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc optimizer.zero_grad() loss.backward() - if clip_grad_mode_value: - torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_sched.learn_rate) - elif clip_grad_mode_norm: - torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_sched.learn_rate) + if clip_grad: + clip_grad(embedding.vec, clip_grad_sched.learn_rate) optimizer.step() -- cgit v1.2.3 From 75c4511e6b81ae8fb0dbd932043e8eb35cd09f72 Mon Sep 17 00:00:00 2001 From: zhaohu xing <920232796@qq.com> Date: Tue, 29 Nov 2022 10:28:41 +0800 Subject: add AltDiffusion to webui Signed-off-by: zhaohu xing <920232796@qq.com> --- configs/altdiffusion/ad-inference.yaml | 72 ++ configs/stable-diffusion/v1-inference.yaml | 71 ++ ldm/data/__init__.py | 0 ldm/data/base.py | 23 + ldm/data/imagenet.py | 394 +++++++ ldm/data/lsun.py | 92 ++ ldm/lr_scheduler.py | 98 ++ ldm/models/autoencoder.py | 443 ++++++++ ldm/models/diffusion/__init__.py | 0 ldm/models/diffusion/classifier.py | 267 +++++ ldm/models/diffusion/ddim.py | 241 +++++ ldm/models/diffusion/ddpm.py | 1445 +++++++++++++++++++++++++ ldm/models/diffusion/dpm_solver/__init__.py | 1 + ldm/models/diffusion/dpm_solver/dpm_solver.py | 1184 ++++++++++++++++++++ ldm/models/diffusion/dpm_solver/sampler.py | 82 ++ ldm/models/diffusion/plms.py | 236 ++++ ldm/modules/attention.py | 261 +++++ ldm/modules/diffusionmodules/__init__.py | 0 ldm/modules/diffusionmodules/model.py | 835 ++++++++++++++ ldm/modules/diffusionmodules/openaimodel.py | 961 
++++++++++++++++ ldm/modules/diffusionmodules/util.py | 267 +++++ ldm/modules/distributions/__init__.py | 0 ldm/modules/distributions/distributions.py | 92 ++ ldm/modules/ema.py | 76 ++ ldm/modules/encoders/__init__.py | 0 ldm/modules/encoders/modules.py | 234 ++++ ldm/modules/encoders/xlmr.py | 137 +++ ldm/modules/image_degradation/__init__.py | 2 + ldm/modules/image_degradation/bsrgan.py | 730 +++++++++++++ ldm/modules/image_degradation/bsrgan_light.py | 650 +++++++++++ ldm/modules/image_degradation/utils/test.png | Bin 0 -> 441072 bytes ldm/modules/image_degradation/utils_image.py | 916 ++++++++++++++++ ldm/modules/losses/__init__.py | 1 + ldm/modules/losses/contperceptual.py | 111 ++ ldm/modules/losses/vqperceptual.py | 167 +++ ldm/modules/x_transformer.py | 641 +++++++++++ ldm/util.py | 203 ++++ modules/devices.py | 4 +- modules/sd_hijack.py | 23 +- modules/shared.py | 6 +- 40 files changed, 10957 insertions(+), 9 deletions(-) create mode 100644 configs/altdiffusion/ad-inference.yaml create mode 100644 configs/stable-diffusion/v1-inference.yaml create mode 100644 ldm/data/__init__.py create mode 100644 ldm/data/base.py create mode 100644 ldm/data/imagenet.py create mode 100644 ldm/data/lsun.py create mode 100644 ldm/lr_scheduler.py create mode 100644 ldm/models/autoencoder.py create mode 100644 ldm/models/diffusion/__init__.py create mode 100644 ldm/models/diffusion/classifier.py create mode 100644 ldm/models/diffusion/ddim.py create mode 100644 ldm/models/diffusion/ddpm.py create mode 100644 ldm/models/diffusion/dpm_solver/__init__.py create mode 100644 ldm/models/diffusion/dpm_solver/dpm_solver.py create mode 100644 ldm/models/diffusion/dpm_solver/sampler.py create mode 100644 ldm/models/diffusion/plms.py create mode 100644 ldm/modules/attention.py create mode 100644 ldm/modules/diffusionmodules/__init__.py create mode 100644 ldm/modules/diffusionmodules/model.py create mode 100644 ldm/modules/diffusionmodules/openaimodel.py create mode 100644 ldm/modules/diffusionmodules/util.py create mode 100644 ldm/modules/distributions/__init__.py create mode 100644 ldm/modules/distributions/distributions.py create mode 100644 ldm/modules/ema.py create mode 100644 ldm/modules/encoders/__init__.py create mode 100644 ldm/modules/encoders/modules.py create mode 100644 ldm/modules/encoders/xlmr.py create mode 100644 ldm/modules/image_degradation/__init__.py create mode 100644 ldm/modules/image_degradation/bsrgan.py create mode 100644 ldm/modules/image_degradation/bsrgan_light.py create mode 100644 ldm/modules/image_degradation/utils/test.png create mode 100644 ldm/modules/image_degradation/utils_image.py create mode 100644 ldm/modules/losses/__init__.py create mode 100644 ldm/modules/losses/contperceptual.py create mode 100644 ldm/modules/losses/vqperceptual.py create mode 100644 ldm/modules/x_transformer.py create mode 100644 ldm/util.py diff --git a/configs/altdiffusion/ad-inference.yaml b/configs/altdiffusion/ad-inference.yaml new file mode 100644 index 00000000..1b11b63e --- /dev/null +++ b/configs/altdiffusion/ad-inference.yaml @@ -0,0 +1,72 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + 
diff --git a/configs/altdiffusion/ad-inference.yaml b/configs/altdiffusion/ad-inference.yaml
new file mode 100644
index 00000000..1b11b63e
--- /dev/null
+++ b/configs/altdiffusion/ad-inference.yaml
@@ -0,0 +1,72 @@
+model:
+  base_learning_rate: 1.0e-04
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false   # Note: different from the one we trained before
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 10000 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.xlmr.BertSeriesModelWithTransformation
+      params:
+        name: "XLMR-Large"
\ No newline at end of file
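Both config files follow the target/params convention used throughout the ldm codebase: "target" names a class by its dotted import path, and "params" holds its constructor arguments. As a point of reference, a minimal sketch of how a config of this shape is typically consumed; the helper mirrors instantiate_from_config from the ldm/util.py added by this patch, and the path and usage are illustrative:

    import importlib

    from omegaconf import OmegaConf

    def instantiate_from_config(config):
        # "target" is a dotted import path; "params" are passed as kwargs.
        module, cls = config["target"].rsplit(".", 1)
        return getattr(importlib.import_module(module), cls)(**config.get("params", dict()))

    config = OmegaConf.load("configs/altdiffusion/ad-inference.yaml")
    model = instantiate_from_config(config.model)  # a LatentDiffusion instance

Nested blocks such as unet_config and first_stage_config are instantiated the same way by the model's own constructor.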
diff --git a/configs/stable-diffusion/v1-inference.yaml b/configs/stable-diffusion/v1-inference.yaml
new file mode 100644
index 00000000..2e6ef0f2
--- /dev/null
+++ b/configs/stable-diffusion/v1-inference.yaml
@@ -0,0 +1,71 @@
+model:
+  base_learning_rate: 1.0e-04
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false   # Note: different from the one we trained before
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 10000 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      # target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+      target: altclip.model.AltCLIPEmbedder
\ No newline at end of file
diff --git a/ldm/data/__init__.py b/ldm/data/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/ldm/data/base.py b/ldm/data/base.py
new file mode 100644
index 00000000..b196c2f7
--- /dev/null
+++ b/ldm/data/base.py
@@ -0,0 +1,23 @@
+from abc import abstractmethod
+from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
+
+
+class Txt2ImgIterableBaseDataset(IterableDataset):
+    '''
+    Define an interface to make the IterableDatasets for text2img data chainable
+    '''
+    def __init__(self, num_records=0, valid_ids=None, size=256):
+        super().__init__()
+        self.num_records = num_records
+        self.valid_ids = valid_ids
+        self.sample_ids = valid_ids
+        self.size = size
+
+        print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
+
+    def __len__(self):
+        return self.num_records
+
+    @abstractmethod
+    def __iter__(self):
+        pass
\ No newline at end of file
diff --git a/ldm/data/imagenet.py b/ldm/data/imagenet.py
new file mode 100644
index 00000000..1c473f9c
--- /dev/null
+++ b/ldm/data/imagenet.py
@@ -0,0 +1,394 @@
+import os, yaml, pickle, shutil, tarfile, glob
+import cv2
+import albumentations
+import PIL
+import numpy as np
+import torchvision.transforms.functional as TF
+from omegaconf import OmegaConf
+from functools import partial
+from PIL import Image
+from tqdm import tqdm
+from torch.utils.data import Dataset, Subset
+
+import taming.data.utils as tdu
+from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
+from taming.data.imagenet import ImagePaths
+
+from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
+
+
+def synset2idx(path_to_yaml="data/index_synset.yaml"):
+    with open(path_to_yaml) as f:
+        di2s = yaml.load(f)
+    return dict((v,k) for k,v in di2s.items())
+
+
+class ImageNetBase(Dataset):
+    def __init__(self, config=None):
+        self.config = config or OmegaConf.create()
+        if not type(self.config)==dict:
+            self.config = OmegaConf.to_container(self.config)
+        self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
+        self.process_images = True  # if False we skip loading & processing images and self.data contains filepaths
+        self._prepare()
+        self._prepare_synset_to_human()
+        self._prepare_idx_to_synset()
+        self._prepare_human_to_integer_label()
+        self._load()
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, i):
+        return self.data[i]
+
+    def _prepare(self):
+        raise NotImplementedError()
+
+    def _filter_relpaths(self, relpaths):
+        ignore = set([
+            "n06596364_9591.JPEG",
+        ])
+        relpaths = [rpath for rpath in relpaths if not 
rpath.split("/")[-1] in ignore] + if "sub_indices" in self.config: + indices = str_to_indices(self.config["sub_indices"]) + synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings + self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) + files = [] + for rpath in relpaths: + syn = rpath.split("/")[0] + if syn in synsets: + files.append(rpath) + return files + else: + return relpaths + + def _prepare_synset_to_human(self): + SIZE = 2655750 + URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" + self.human_dict = os.path.join(self.root, "synset_human.txt") + if (not os.path.exists(self.human_dict) or + not os.path.getsize(self.human_dict)==SIZE): + download(URL, self.human_dict) + + def _prepare_idx_to_synset(self): + URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" + self.idx2syn = os.path.join(self.root, "index_synset.yaml") + if (not os.path.exists(self.idx2syn)): + download(URL, self.idx2syn) + + def _prepare_human_to_integer_label(self): + URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" + self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") + if (not os.path.exists(self.human2integer)): + download(URL, self.human2integer) + with open(self.human2integer, "r") as f: + lines = f.read().splitlines() + assert len(lines) == 1000 + self.human2integer_dict = dict() + for line in lines: + value, key = line.split(":") + self.human2integer_dict[key] = int(value) + + def _load(self): + with open(self.txt_filelist, "r") as f: + self.relpaths = f.read().splitlines() + l1 = len(self.relpaths) + self.relpaths = self._filter_relpaths(self.relpaths) + print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) + + self.synsets = [p.split("/")[0] for p in self.relpaths] + self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] + + unique_synsets = np.unique(self.synsets) + class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) + if not self.keep_orig_class_label: + self.class_labels = [class_dict[s] for s in self.synsets] + else: + self.class_labels = [self.synset2idx[s] for s in self.synsets] + + with open(self.human_dict, "r") as f: + human_dict = f.read().splitlines() + human_dict = dict(line.split(maxsplit=1) for line in human_dict) + + self.human_labels = [human_dict[s] for s in self.synsets] + + labels = { + "relpath": np.array(self.relpaths), + "synsets": np.array(self.synsets), + "class_label": np.array(self.class_labels), + "human_label": np.array(self.human_labels), + } + + if self.process_images: + self.size = retrieve(self.config, "size", default=256) + self.data = ImagePaths(self.abspaths, + labels=labels, + size=self.size, + random_crop=self.random_crop, + ) + else: + self.data = self.abspaths + + +class ImageNetTrain(ImageNetBase): + NAME = "ILSVRC2012_train" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" + FILES = [ + "ILSVRC2012_img_train.tar", + ] + SIZES = [ + 147897477120, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.process_images = process_images + self.data_root = data_root + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + + self.datadir = 
os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 1281167 + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", + default=True) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + print("Extracting sub-tars.") + subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) + for subpath in tqdm(subpaths): + subdir = subpath[:-len(".tar")] + os.makedirs(subdir, exist_ok=True) + with tarfile.open(subpath, "r:") as tar: + tar.extractall(path=subdir) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + +class ImageNetValidation(ImageNetBase): + NAME = "ILSVRC2012_validation" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" + VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" + FILES = [ + "ILSVRC2012_img_val.tar", + "validation_synset.txt", + ] + SIZES = [ + 6744924160, + 1950000, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.data_root = data_root + self.process_images = process_images + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 50000 + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", + default=False) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + vspath = os.path.join(self.root, self.FILES[1]) + if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + download(self.VS_URL, vspath) + + with open(vspath, "r") as f: + synset_dict = f.read().splitlines() + synset_dict = dict(line.split() for line in synset_dict) + + print("Reorganizing into synset folders") + synsets = np.unique(list(synset_dict.values())) + for s in synsets: + os.makedirs(os.path.join(datadir, s), exist_ok=True) + for k, v in synset_dict.items(): + src = os.path.join(datadir, k) + dst = os.path.join(datadir, v) + shutil.move(src, dst) + + 
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + + +class ImageNetSR(Dataset): + def __init__(self, size=None, + degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., + random_crop=True): + """ + Imagenet Superresolution Dataloader + Performs following ops in order: + 1. crops a crop of size s from image either as random or center crop + 2. resizes crop to size with cv2.area_interpolation + 3. degrades resized crop with degradation_fn + + :param size: resizing to size after cropping + :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light + :param downscale_f: Low Resolution Downsample factor + :param min_crop_f: determines crop size s, + where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f) + :param max_crop_f: "" + :param data_root: + :param random_crop: + """ + self.base = self.get_base() + assert size + assert (size / downscale_f).is_integer() + self.size = size + self.LR_size = int(size / downscale_f) + self.min_crop_f = min_crop_f + self.max_crop_f = max_crop_f + assert(max_crop_f <= 1.) + self.center_crop = not random_crop + + self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) + + self.pil_interpolation = False # gets reset later if incase interp_op is from pillow + + if degradation == "bsrgan": + self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) + + elif degradation == "bsrgan_light": + self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) + + else: + interpolation_fn = { + "cv_nearest": cv2.INTER_NEAREST, + "cv_bilinear": cv2.INTER_LINEAR, + "cv_bicubic": cv2.INTER_CUBIC, + "cv_area": cv2.INTER_AREA, + "cv_lanczos": cv2.INTER_LANCZOS4, + "pil_nearest": PIL.Image.NEAREST, + "pil_bilinear": PIL.Image.BILINEAR, + "pil_bicubic": PIL.Image.BICUBIC, + "pil_box": PIL.Image.BOX, + "pil_hamming": PIL.Image.HAMMING, + "pil_lanczos": PIL.Image.LANCZOS, + }[degradation] + + self.pil_interpolation = degradation.startswith("pil_") + + if self.pil_interpolation: + self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) + + else: + self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, + interpolation=interpolation_fn) + + def __len__(self): + return len(self.base) + + def __getitem__(self, i): + example = self.base[i] + image = Image.open(example["file_path_"]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + image = np.array(image).astype(np.uint8) + + min_side_len = min(image.shape[:2]) + crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) + crop_side_len = int(crop_side_len) + + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) + + else: + self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) + + image = self.cropper(image=image)["image"] + image = self.image_rescaler(image=image)["image"] + + if self.pil_interpolation: + image_pil = PIL.Image.fromarray(image) + LR_image = self.degradation_process(image_pil) + LR_image = np.array(LR_image).astype(np.uint8) + + else: + LR_image = self.degradation_process(image=image)["image"] + + example["image"] = (image/127.5 - 1.0).astype(np.float32) + 
example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) + + return example + + +class ImageNetSRTrain(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_train_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetTrain(process_images=False,) + return Subset(dset, indices) + + +class ImageNetSRValidation(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_val_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetValidation(process_images=False,) + return Subset(dset, indices) diff --git a/ldm/data/lsun.py b/ldm/data/lsun.py new file mode 100644 index 00000000..6256e457 --- /dev/null +++ b/ldm/data/lsun.py @@ -0,0 +1,92 @@ +import os +import numpy as np +import PIL +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + + +class LSUNBase(Dataset): + def __init__(self, + txt_file, + data_root, + size=None, + interpolation="bicubic", + flip_p=0.5 + ): + self.data_paths = txt_file + self.data_root = data_root + with open(self.data_paths, "r") as f: + self.image_paths = f.read().splitlines() + self._length = len(self.image_paths) + self.labels = { + "relative_file_path_": [l for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, l) + for l in self.image_paths], + } + + self.size = size + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = dict((k, self.labels[k][i]) for k in self.labels) + image = Image.open(example["file_path_"]) + if not image.mode == "RGB": + image = image.convert("RGB") + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + + image = Image.fromarray(img) + if self.size is not None: + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip(image) + image = np.array(image).astype(np.uint8) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + return example + + +class LSUNChurchesTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) + + +class LSUNChurchesValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", + flip_p=flip_p, **kwargs) + + +class LSUNBedroomsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) + + +class LSUNBedroomsValidation(LSUNBase): + def __init__(self, flip_p=0.0, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", + flip_p=flip_p, **kwargs) + + +class LSUNCatsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) + + +class LSUNCatsValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", + 
flip_p=flip_p, **kwargs) diff --git a/ldm/lr_scheduler.py b/ldm/lr_scheduler.py new file mode 100644 index 00000000..be39da9c --- /dev/null +++ b/ldm/lr_scheduler.py @@ -0,0 +1,98 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n,**kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. 
+ self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f + diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py new file mode 100644 index 00000000..6a9c4f45 --- /dev/null +++ b/ldm/models/autoencoder.py @@ -0,0 +1,443 @@ +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config + + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) 
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_,_,ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + predicted_indices=ind) + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, 
batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"val{suffix}/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + if version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor*self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr_g, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
+ return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, 
reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/ldm/models/diffusion/__init__.py b/ldm/models/diffusion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ldm/models/diffusion/classifier.py b/ldm/models/diffusion/classifier.py new file mode 100644 index 00000000..67e98b9d --- /dev/null +++ b/ldm/models/diffusion/classifier.py @@ -0,0 +1,267 @@ +import os +import torch +import pytorch_lightning as pl +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from copy import deepcopy +from einops import rearrange +from glob import glob +from natsort import natsorted + +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config + +__models__ = { + 'class_label': EncoderUNetModel, + 'segmentation': UNetModel +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + + def __init__(self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.e-2, + log_steps=10, + monitor='val/loss', + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.num_classes = num_classes + # get latest 
config of diffusion model + diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + else self.diffusion_model.cond_stage_key + + assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = instantiate_from_config(self.diffusion_config) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, ckpt_path, pool): + model_config = deepcopy(self.diffusion_config.params.unet_config.params) + model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config.out_channels = self.num_classes + if self.label_key == 'class_label': + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print('#####################################################################') + print(f'load from ckpt "{ckpt_path}"') + print('#####################################################################') + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + # todo: make sure t+1 is correct here + + return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, k=None): + if k is None: + k = self.label_key + assert k is not None, 'Needs to provide label key' + + targets = batch[k].to(self.device) + + if 
self.label_key == 'segmentation': + targets = rearrange(targets, 'b h w c -> b c h w') + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets + + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to('cpu') + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = 'train' if self.training else 'val' + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k( + logits, targets, k=1, reduction="mean" + ) + log[f"{log_prefix}/acc@5"] = self.compute_top_k( + logits, targets, k=5, reduction="mean" + ) + + self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) + self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() + else: + t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction='none') + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in + range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) + self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + + return loss + + def configure_optimizers(self): + optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + + if self.use_scheduler: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = self.get_input(batch, self.diffusion_model.first_stage_key) + log['inputs'] = x + + y = self.get_conditioning(batch) + + if self.label_key == 'class_label': 
+ y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['labels'] = y + + if ismap(y): + log['labels'] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f'inputs@t{current_time}'] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = rearrange(pred, 'b h w c -> b c h w') + + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py new file mode 100644 index 00000000..fb31215d --- /dev/null +++ b/ldm/models/diffusion/ddim.py @@ -0,0 +1,241 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ + extract_into_tensor + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. 
- ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. 
- mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + return x_dec \ No newline at end of file diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py new file mode 100644 index 00000000..bbedd04c --- /dev/null +++ b/ldm/models/diffusion/ddpm.py @@ -0,0 +1,1445 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.distributed import rank_zero_only + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change 
anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. 
- betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. 
* 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) 
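+        # With the denoised estimate x_recon (an approximation of x_0) in hand, the
+        # closed-form posterior q(x_{t-1} | x_t, x_0) below supplies the mean and
+        # (log-)variance that p_sample uses to draw x_{t-1}.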
+ + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = 
x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + 
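+            # '__is_unconditional__' disables conditioning entirely: with
+            # conditioning_key=None the DiffusionWrapper falls back to a plain
+            # unconditional DDPM forward pass.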
conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. 
/ z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = 
torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None: + if cond_key is None: + cond_key = 
self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'coordinates_bbox']: + xc = batch[cond_key] + elif cond_key == 'class_label': + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + # import pudb; pudb.set_trace() + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {'pos_x': pos_x, 'pos_y': pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. 
reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + # same as above but without decorator + def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. 
(64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def rescale_bbox(bbox): + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def apply_model(self, x_noisy, t, cond, return_ids=False): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert len(cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. 
(64, 64)
+
+            h, w = x_noisy.shape[-2:]
+
+            fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
+
+            z = unfold(x_noisy)  # (bn, nc * prod(**ks), L)
+            # Reshape to img shape
+            z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )
+            z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
+
+            if self.cond_stage_key in ["image", "LR_image", "segmentation",
+                                       'bbox_img'] and self.model.conditioning_key:  # todo check for completeness
+                c_key = next(iter(cond.keys()))  # get key
+                c = next(iter(cond.values()))  # get value
+                assert (len(c) == 1)  # todo extend to list with more than one elem
+                c = c[0]  # get element
+
+                c = unfold(c)
+                c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1]))  # (bn, nc, ks[0], ks[1], L )
+
+                cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
+
+            elif self.cond_stage_key == 'coordinates_bbox':
+                assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
+
+                # assuming padding of unfold is always 0 and its dilation is always 1
+                n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
+                full_img_h, full_img_w = self.split_input_params['original_image_size']
+                # as we are operating on latents, we need the factor from the original image size to the
+                # spatial latent size to properly rescale the crops for regenerating the bbox annotations
+                num_downs = self.first_stage_model.encoder.num_resolutions - 1
+                rescale_latent = 2 ** (num_downs)
+
+                # get top-left positions of patches, conforming to the bbox tokenizer; we therefore
+                # need to rescale the tl patch coordinates so they lie in (0, 1)
+                tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
+                                         rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
+                                        for patch_nr in range(z.shape[-1])]
+
+                # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, w, h)
+                patch_limits = [(x_tl, y_tl,
+                                 rescale_latent * ks[0] / full_img_w,
+                                 rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
+                # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
+
+                # tokenize crop coordinates for the bounding boxes of the respective patches
+                patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
+                                      for bbox in patch_limits]  # list of length l with tensors of shape (1, 2)
+                print(patch_limits_tknzd[0].shape)
+                # cut tknzd crop position from conditioning
+                assert isinstance(cond, dict), 'cond must be dict to be fed into model'
+                cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
+                print(cut_cond.shape)
+
+                adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
+                adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
+                print(adapted_cond.shape)
+                adapted_cond = self.get_learned_conditioning(adapted_cond)
+                print(adapted_cond.shape)
+                adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
+                print(adapted_cond.shape)
+
+                cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
+
+            else:
+                cond_list = [cond for i in range(z.shape[-1])]  # todo make this more efficient
+
+            # apply model by loop over crops
+            output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
+            assert not isinstance(output_list[0],
+                                  tuple)  # todo: can't deal with multiple model outputs; check this never happens
+
+            o = torch.stack(output_list,
axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) 
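+        # Optionally snap the denoised latent onto the first-stage VQ codebook before
+        # forming the posterior; this is what the quantize_denoised sampling option
+        # toggles.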
+ if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + 
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None,**kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): + + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, + shape,cond,verbose=False,**kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True,**kwargs) + + return samples, intermediates + + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, **kwargs): + + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + 
return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with self.ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] 
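+            # mask == 1 keeps the encoded latent z, mask == 0 marks the centre square
+            # the sampler is asked to regenerate (see the x0/mask blending in
+            # p_sample_loop).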
+ with self.ema_scope("Plotting Inpaint"): + + samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + with self.ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with self.ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. 
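+        # the random 1x1 convolution above projects an arbitrary channel count down
+        # to 3, and the min/max rescaling brings the result into [-1, 1] for logging.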
+ return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class Layout2ImgDiffusion(LatentDiffusion): + # TODO: move all layout-specific hacks to this class + def __init__(self, cond_stage_key, *args, **kwargs): + assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' + super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + + def log_images(self, batch, N=8, *args, **kwargs): + logs = super().log_images(batch=batch, N=N, *args, **kwargs) + + key = 'train' if self.training else 'validation' + dset = self.trainer.datamodule.datasets[key] + mapper = dset.conditional_builders[self.cond_stage_key] + + bbox_imgs = [] + map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) + for tknzd_bbox in batch[self.cond_stage_key][:N]: + bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) + bbox_imgs.append(bboximg) + + cond_img = torch.stack(bbox_imgs, dim=0) + logs['bbox_image'] = cond_img + return logs diff --git a/ldm/models/diffusion/dpm_solver/__init__.py b/ldm/models/diffusion/dpm_solver/__init__.py new file mode 100644 index 00000000..7427f38c --- /dev/null +++ b/ldm/models/diffusion/dpm_solver/__init__.py @@ -0,0 +1 @@ +from .sampler import DPMSolverSampler \ No newline at end of file diff --git a/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ldm/models/diffusion/dpm_solver/dpm_solver.py new file mode 100644 index 00000000..bdb64e0c --- /dev/null +++ b/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -0,0 +1,1184 @@ +import torch +import torch.nn.functional as F +import math + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. 
For t in [0, T], we have:
+
+            log_alpha_t = self.marginal_log_mean_coeff(t)
+            sigma_t = self.marginal_std(t)
+            lambda_t = self.marginal_lambda(t)
+
+        Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+
+            t = self.inverse_lambda(lambda_t)
+
+        ===============================================================
+
+        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
+
+        1. For discrete-time DPMs:
+
+            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
+                t_i = (i + 1) / N
+            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+
+            Args:
+                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
+                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
+
+            Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
+
+            **Important**: Please pay special attention to the args for `alphas_cumprod`:
+                The `alphas_cumprod` is the \hat{alpha_n} array in the notation of DDPM. Specifically, DDPMs assume that
+                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
+                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
+                and
+                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+
+
+        2. For continuous-time DPMs:
+
+            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
+            schedule are the default settings in DDPM and improved-DDPM:
+
+            Args:
+                beta_min: A `float` number. The smallest beta for the linear schedule.
+                beta_max: A `float` number. The largest beta for the linear schedule.
+                cosine_s: A `float` number. The hyperparameter in the cosine schedule.
+                cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
+                T: A `float` number. The ending time of the forward process.
+
+        ===============================================================
+
+        Args:
+            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+                    'linear' or 'cosine' for continuous-time DPMs.
+        Returns:
+            A wrapper object of the forward SDE (VP type).
+
+        ===============================================================
+
+        Example:
+
+        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+        >>> ns = NoiseScheduleVP('discrete', betas=betas)
+
+        # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
+        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+
+        # For continuous-time DPMs (VPSDE), linear schedule:
+        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
+
+        """
+
+        if schedule not in ['discrete', 'linear', 'cosine']:
+            raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
+
+        self.schedule = schedule
+        if schedule == 'discrete':
+            if betas is not None:
+                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+            else:
+                assert alphas_cumprod is not None
+                log_alphas = 0.5 * torch.log(alphas_cumprod)
+            self.total_N = len(log_alphas)
+            self.T = 1.
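+            # For 'discrete' schedules a lookup table is kept: t_array holds the N
+            # time labels t_i = (i + 1) / N and log_alpha_array the matching
+            # log(alpha_t) values; marginal_log_mean_coeff later interpolates
+            # between them piecewise-linearly via interpolate_fn.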
+            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
+            self.log_alpha_array = log_alphas.reshape((1, -1,))
+        else:
+            self.total_N = 1000
+            self.beta_0 = continuous_beta_0
+            self.beta_1 = continuous_beta_1
+            self.cosine_s = 0.008
+            self.cosine_beta_max = 999.
+            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
+            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
+            if schedule == 'cosine':
+                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
+                # Note that T = 0.9946 may not be the optimal setting. However, we find it works well.
+                self.T = 0.9946
+            else:
+                self.T = 1.
+
+    def marginal_log_mean_coeff(self, t):
+        """
+        Compute log(alpha_t) of a given continuous-time label t in [0, T].
+        """
+        if self.schedule == 'discrete':
+            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
+        elif self.schedule == 'linear':
+            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+        elif self.schedule == 'cosine':
+            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
+            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+            return log_alpha_t
+
+    def marginal_alpha(self, t):
+        """
+        Compute alpha_t of a given continuous-time label t in [0, T].
+        """
+        return torch.exp(self.marginal_log_mean_coeff(t))
+
+    def marginal_std(self, t):
+        """
+        Compute sigma_t of a given continuous-time label t in [0, T].
+        """
+        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+    def marginal_lambda(self, t):
+        """
+        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+        """
+        log_mean_coeff = self.marginal_log_mean_coeff(t)
+        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+        return log_mean_coeff - log_std
+
+    def inverse_lambda(self, lamb):
+        """
+        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+        """
+        if self.schedule == 'linear':
+            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+            Delta = self.beta_0**2 + tmp
+            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+        elif self.schedule == 'discrete':
+            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
+            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
+            return t.reshape((-1,))
+        else:
+            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+            t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
+            t = t_fn(log_alpha)
+            return t
+
+
+def model_wrapper(
+    model,
+    noise_schedule,
+    model_type="noise",
+    model_kwargs={},
+    guidance_type="uncond",
+    condition=None,
+    unconditional_condition=None,
+    guidance_scale=1.,
+    classifier_fn=None,
+    classifier_kwargs={},
+):
+    """Create a wrapper function for the noise prediction model.
+
+    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we first
+    need to wrap the model function to a noise prediction model that accepts the continuous time as the input.
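+
+    For example, a minimal sketch for a discrete-time noise prediction model (`betas` and
+    `model` are assumed to exist):
+
+    >>> ns = NoiseScheduleVP('discrete', betas=betas)
+    >>> model_fn = model_wrapper(model, ns, model_type="noise")
+    >>> # model_fn(x, t_continuous) now takes continuous time labels in (0, 1].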
+
+    We support four types of diffusion models by setting `model_type`:
+
+        1. "noise": noise prediction model. (Trained by predicting noise).
+
+        2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
+
+        3. "v": velocity prediction model. (Trained by predicting the velocity).
+            The "v" prediction is derived in detail in Appendix D of [1], and is used in Imagen-Video [2].
+
+            [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
+                arXiv preprint arXiv:2202.00512 (2022).
+            [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
+                arXiv preprint arXiv:2210.02303 (2022).
+
+        4. "score": marginal score function. (Trained by denoising score matching).
+            Note that the score function and the noise prediction model follow a simple relationship:
+            ```
+                noise(x_t, t) = -sigma_t * score(x_t, t)
+            ```
+
+    We support three types of guided sampling by DPMs by setting `guidance_type`:
+        1. "uncond": unconditional sampling by DPMs.
+            The input `model` has the following format:
+            ``
+                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+            ``
+
+        2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
+            The input `model` has the following format:
+            ``
+                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+            ``
+
+            The input `classifier_fn` has the following format:
+            ``
+                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
+            ``
+
+            [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+
+        3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
+            The input `model` has the following format:
+            ``
+                model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
+            ``
+            And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
+
+            [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+                arXiv preprint arXiv:2207.12598 (2022).
+
+
+    The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
+    or continuous-time labels (i.e. epsilon to T).
+
+    We wrap the model function to accept only `x` and `t_continuous` as inputs, and output the predicted noise:
+    ``
+        def model_fn(x, t_continuous) -> noise:
+            t_input = get_model_input_time(t_continuous)
+            return noise_pred(model, x, t_input, **model_kwargs)
+    ``
+    where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
+
+    ===============================================================
+
+    Args:
+        model: A diffusion model with the corresponding format described above.
+        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+        model_type: A `str`. The parameterization type of the diffusion model.
+                    "noise" or "x_start" or "v" or "score".
+        model_kwargs: A `dict`. A dict for the other inputs of the model function.
+        guidance_type: A `str`. The type of the guidance for sampling.
+                    "uncond" or "classifier" or "classifier-free".
+        condition: A pytorch tensor. The condition for the guided sampling.
+                    Only used for "classifier" or "classifier-free" guidance type.
+        unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
+                    Only used for "classifier-free" guidance type.
+        guidance_scale: A `float`. The scale for the guided sampling.
+        classifier_fn: A classifier function. Only used for the classifier guidance.
+        classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
+    Returns:
+        A noise prediction model that accepts the noised data and the continuous time as the inputs.
+    """
+
+    def get_model_input_time(t_continuous):
+        """
+        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+        For continuous-time DPMs, we just use `t_continuous`.
+        """
+        if noise_schedule.schedule == 'discrete':
+            return (t_continuous - 1. / noise_schedule.total_N) * 1000.
+        else:
+            return t_continuous
+
+    def noise_pred_fn(x, t_continuous, cond=None):
+        if t_continuous.reshape((-1,)).shape[0] == 1:
+            t_continuous = t_continuous.expand((x.shape[0]))
+        t_input = get_model_input_time(t_continuous)
+        if cond is None:
+            output = model(x, t_input, **model_kwargs)
+        else:
+            output = model(x, t_input, cond, **model_kwargs)
+        if model_type == "noise":
+            return output
+        elif model_type == "x_start":
+            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+            dims = x.dim()
+            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
+        elif model_type == "v":
+            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+            dims = x.dim()
+            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
+        elif model_type == "score":
+            sigma_t = noise_schedule.marginal_std(t_continuous)
+            dims = x.dim()
+            return -expand_dims(sigma_t, dims) * output
+
+    def cond_grad_fn(x, t_input):
+        """
+        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
+        """
+        with torch.enable_grad():
+            x_in = x.detach().requires_grad_(True)
+            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
+            return torch.autograd.grad(log_prob.sum(), x_in)[0]
+
+    def model_fn(x, t_continuous):
+        """
+        The noise prediction model function that is used for DPM-Solver.
+        """
+        if t_continuous.reshape((-1,)).shape[0] == 1:
+            t_continuous = t_continuous.expand((x.shape[0]))
+        if guidance_type == "uncond":
+            return noise_pred_fn(x, t_continuous)
+        elif guidance_type == "classifier":
+            assert classifier_fn is not None
+            t_input = get_model_input_time(t_continuous)
+            cond_grad = cond_grad_fn(x, t_input)
+            sigma_t = noise_schedule.marginal_std(t_continuous)
+            noise = noise_pred_fn(x, t_continuous)
+            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
+        elif guidance_type == "classifier-free":
+            if guidance_scale == 1. or unconditional_condition is None:
+                return noise_pred_fn(x, t_continuous, cond=condition)
+            else:
+                x_in = torch.cat([x] * 2)
+                t_in = torch.cat([t_continuous] * 2)
+                c_in = torch.cat([unconditional_condition, condition])
+                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
+                return noise_uncond + guidance_scale * (noise - noise_uncond)
+
+    # "score" is also handled by noise_pred_fn above, so accept it here as well.
+    assert model_type in ["noise", "x_start", "v", "score"]
+    assert guidance_type in ["uncond", "classifier", "classifier-free"]
+    return model_fn
+
+
+class DPM_Solver:
+    def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
+        """Construct a DPM-Solver.
+
+        We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
+        If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
+        If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
+        In such a case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
+        The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
+
+        Args:
+            model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
+                ``
+                def model_fn(x, t_continuous):
+                    return noise
+                ``
+            noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+            predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
+            thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
+            max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
+
+        [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
+        """
+        self.model = model_fn
+        self.noise_schedule = noise_schedule
+        self.predict_x0 = predict_x0
+        self.thresholding = thresholding
+        self.max_val = max_val
+
+    def noise_prediction_fn(self, x, t):
+        """
+        Return the noise prediction model.
+        """
+        return self.model(x, t)
+
+    def data_prediction_fn(self, x, t):
+        """
+        Return the data prediction model (with thresholding).
+        """
+        noise = self.noise_prediction_fn(x, t)
+        dims = x.dim()
+        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
+        x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
+        if self.thresholding:
+            p = 0.995   # A hyperparameter in the paper of "Imagen" [1].
+            s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+            s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
+            x0 = torch.clamp(x0, -s, s) / s
+        return x0
+
+    def model_fn(self, x, t):
+        """
+        Convert the model to the noise prediction model or the data prediction model.
+        """
+        if self.predict_x0:
+            return self.data_prediction_fn(x, t)
+        else:
+            return self.noise_prediction_fn(x, t)
+
+    def get_time_steps(self, skip_type, t_T, t_0, N, device):
+        """Compute the intermediate time steps for sampling.
+
+        Args:
+            skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+                - 'logSNR': uniform logSNR for the time steps.
+                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
+                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
+            t_T: A `float`. The starting time of the sampling (default is T).
+            t_0: A `float`. The ending time of the sampling (default is epsilon).
+            N: An `int`. The total number of the spacing of the time steps.
+            device: A torch device.
+        Returns:
+            A pytorch tensor of the time steps, with the shape (N + 1,).
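+
+        Example (a sketch, assuming a constructed `dpm_solver` instance):
+
+        >>> # 10 uniform-in-time steps from t_T=1. down to t_0=1e-3 -> a tensor of shape (11,)
+        >>> ts = dpm_solver.get_time_steps('time_uniform', t_T=1., t_0=1e-3, N=10, device='cpu')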
+ """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3,] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3,] * (K - 1) + [1] + else: + orders = [3,] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2,] * K + else: + K = steps // 2 + 1 + orders = [2,] * (K - 1) + [1] + elif order == 1: + K = 1 + orders = [1,] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders)).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. 
+ """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.predict_x0: + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
+ """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. 
+ Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(sigma_s2 / sigma_s, dims) * x + - expand_dims(alpha_s2 * phi_12, dims) * model_s + + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + expand_dims(alpha_t * phi_2, dims) * D1 + - expand_dims(alpha_t * phi_3, dims) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x + - expand_dims(sigma_s2 * phi_12, dims) * model_s + - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. 
* (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - expand_dims(sigma_t * phi_2, dims) * D1 + - expand_dims(sigma_t * phi_3, dims) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + dims = x.dim() + model_prev_1, model_prev_0 = model_prev_list + t_prev_1, t_prev_0 = t_prev_list + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + if self.predict_x0: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 + ) + else: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. 
We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) + D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) + D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) + if self.predict_x0: + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 + - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h**2 - 0.5), dims) * D2 + ) + else: + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 + - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h**2 - 0.5), dims) * D2 + ) + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. 
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpm_solver'): + """ + The adaptive step size solver based on singlestep DPM-Solver. + + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((x.shape[0],)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. 
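+            # DPM-Solver-23: the second-order update serves as the cheap (lower-order) estimate
+            # and the third-order update as the refined one; their gap drives the step size control below.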
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', + method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', + atol=0.0078, rtol=0.05, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + + ===================================================== + + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). 
+                We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
+            - 'adaptive':
+                Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
+                We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
+                You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
+                (NFE) and the sample quality.
+                - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
+                - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
+
+        =====================================================
+
+        Some advice for choosing the algorithm:
+            - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
+                Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
+                e.g.
+                >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
+                >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
+                        skip_type='time_uniform', method='singlestep')
+            - For **guided sampling with large guidance scale** by DPMs:
+                Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
+                e.g.
+                >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
+                >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
+                        skip_type='time_uniform', method='multistep')
+
+        We support three types of `skip_type`:
+            - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**
+            - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
+            - 'time_quadratic': quadratic time for the time steps.
+
+        =====================================================
+        Args:
+            x: A pytorch tensor. The initial value at time `t_start`
+                e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
+            steps: An `int`. The total number of function evaluations (NFE).
+            t_start: A `float`. The starting time of the sampling.
+                If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
+            t_end: A `float`. The ending time of the sampling.
+                If `t_end` is None, we use 1. / self.noise_schedule.total_N.
+                e.g. if total_N == 1000, we have `t_end` == 1e-3.
+                For discrete-time DPMs:
+                    - We recommend `t_end` == 1. / self.noise_schedule.total_N.
+                For continuous-time DPMs:
+                    - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
+            order: An `int`. The order of DPM-Solver.
+            skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
+            method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
+            denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
+                Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
+
+                This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+                score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID
+                for diffusion models sampling by diffusion SDEs for low-resolution images
+                (such as CIFAR-10). However, we observed that this trick does not matter for
+                high-resolution images. As it needs an additional NFE, we do not recommend
+                it for high-resolution images.
+            lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + device = x.device + if method == 'adaptive': + with torch.no_grad(): + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + with torch.no_grad(): + vec_t = timesteps[0].expand((x.shape[0])) + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + # Init the first `order` values by lower order multistep DPM-Solver. + for init_order in range(1, order): + vec_t = timesteps[init_order].expand(x.shape[0]) + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type) + model_prev_list.append(self.model_fn(x, vec_t)) + t_prev_list.append(vec_t) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in range(order, steps + 1): + vec_t = timesteps[step].expand(x.shape[0]) + if lower_order_final and steps < 15: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, solver_type=solver_type) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. 
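+                    # (an evaluation at the last timestep would cost one extra NFE and never be used)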
+                    if step < steps:
+                        model_prev_list[-1] = self.model_fn(x, vec_t)
+        elif method in ['singlestep', 'singlestep_fixed']:
+            if method == 'singlestep':
+                timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device)
+            elif method == 'singlestep_fixed':
+                K = steps // order
+                orders = [order,] * K
+                timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
+            for i, order in enumerate(orders):
+                t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
+                timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), N=order, device=device)
+                lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
+                vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
+                h = lambda_inner[-1] - lambda_inner[0]
+                r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
+                r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
+                x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
+        if denoise_to_zero:
+            x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+        return x
+
+
+
+#############################################################
+# other utility functions
+#############################################################
+
+def interpolate_fn(x, xp, yp):
+    """
+    A piecewise linear function y = f(x), using xp and yp as keypoints.
+    We implement f(x) in a differentiable way (i.e. applicable for autograd).
+    The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
+
+    Args:
+        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
+        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+        yp: PyTorch tensor with shape [C, K].
+    Returns:
+        The function values f(x), with shape [N, C].
+    """
+    N, K = x.shape[0], xp.shape[1]
+    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
+    sorted_all_x, x_indices = torch.sort(all_x, dim=2)
+    x_idx = torch.argmin(x_indices, dim=2)
+    cand_start_idx = x_idx - 1
+    start_idx = torch.where(
+        torch.eq(x_idx, 0),
+        torch.tensor(1, device=x.device),
+        torch.where(
+            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+        ),
+    )
+    end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
+    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
+    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
+    start_idx2 = torch.where(
+        torch.eq(x_idx, 0),
+        torch.tensor(0, device=x.device),
+        torch.where(
+            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+        ),
+    )
+    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
+    start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
+    end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
+    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
+    return cand
+
+
+def expand_dims(v, dims):
+    """
+    Expand the tensor `v` to the dim `dims`.
+
+    Args:
+        `v`: a PyTorch tensor with shape [N].
+        `dims`: an `int`.
+    Returns:
+        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
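+
+    Example:
+
+    >>> expand_dims(torch.ones(4), 4).shape
+    torch.Size([4, 1, 1, 1])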
+ """ + return v[(...,) + (None,)*(dims - 1)] \ No newline at end of file diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py new file mode 100644 index 00000000..2c42d6f9 --- /dev/null +++ b/ldm/models/diffusion/dpm_solver/sampler.py @@ -0,0 +1,82 @@ +"""SAMPLING ONLY.""" + +import torch + +from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver + + +class DPMSolverSampler(object): + def __init__(self, model, **kwargs): + super().__init__() + self.model = model + to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) + self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') + + device = self.model.betas.device + if x_T is None: + img = torch.randn(size, device=device) + else: + img = x_T + + ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + + model_fn = model_wrapper( + lambda x, t, c: self.model.apply_model(x, t, c), + ns, + model_type="noise", + guidance_type="classifier-free", + condition=conditioning, + unconditional_condition=unconditional_conditioning, + guidance_scale=unconditional_guidance_scale, + ) + + dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) + x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) + + return x.to(device), None diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py new file mode 100644 index 00000000..78eeb100 --- /dev/null +++ b/ldm/models/diffusion/plms.py @@ -0,0 +1,236 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for 
PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. 
- mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t
+            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+            if noise_dropout > 0.:
+                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+            return x_prev, pred_x0
+
+        e_t = get_model_output(x, t)
+        if len(old_eps) == 0:
+            # Pseudo Improved Euler (2nd order)
+            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+            e_t_next = get_model_output(x_prev, t_next)
+            e_t_prime = (e_t + e_t_next) / 2
+        elif len(old_eps) == 1:
+            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (3 * e_t - old_eps[-1]) / 2
+        elif len(old_eps) == 2:
+            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+        elif len(old_eps) >= 3:
+            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
+            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+        return x_prev, pred_x0, e_t
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
new file mode 100644
index 00000000..f4eff39c
--- /dev/null
+++ b/ldm/modules/attention.py
@@ -0,0 +1,261 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+
+from ldm.modules.diffusionmodules.util import checkpoint
+
+
+def exists(val):
+    return val is not None
+
+
+def uniq(arr):
+    return{el: True for el in arr}.keys()
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+    return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+    dim = tensor.shape[-1]
+    std = 1 / math.sqrt(dim)
+    tensor.uniform_(-std, std)
+    return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+    def __init__(self, dim_in, dim_out):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out * 2)
+
+    def forward(self, x):
+        x, gate = self.proj(x).chunk(2, dim=-1)
+        return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        dim_out = default(dim_out, dim)
+        project_in = nn.Sequential(
+            nn.Linear(dim, inner_dim),
+            nn.GELU()
+        ) if not glu else GEGLU(dim, inner_dim)
+
+        self.net = nn.Sequential(
+            project_in,
+            nn.Dropout(dropout),
+            nn.Linear(inner_dim, dim_out)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... 
-> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): + super().__init__() + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c') + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) + x = self.proj_out(x) + return x + x_in \ No newline at end of file diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py new file mode 100644 index 00000000..533e589a --- /dev/null +++ b/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,835 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from ldm.util import instantiate_from_config +from ldm.modules.attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = 
torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + 
if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + 
self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = 
self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, 
+                                                     out_channels=mid_channels,
+                                                     temb_channels=0,
+                                                     dropout=0.0) for _ in range(depth)])
+
+        self.conv_out = nn.Conv2d(mid_channels,
+                                  out_channels,
+                                  kernel_size=1,
+                                  )
+
+    def forward(self, x):
+        x = self.conv_in(x)
+        for block in self.res_block1:
+            x = block(x, None)
+        x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+        x = self.attn(x)
+        for block in self.res_block2:
+            x = block(x, None)
+        x = self.conv_out(x)
+        return x
+
+
+class MergedRescaleEncoder(nn.Module):
+    def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
+                 ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+        super().__init__()
+        intermediate_chn = ch * ch_mult[-1]
+        self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+                               z_channels=intermediate_chn, double_z=False, resolution=resolution,
+                               attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+                               out_ch=None)
+        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+                                       mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.rescaler(x)
+        return x
+
+
+class MergedRescaleDecoder(nn.Module):
+    def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+                 dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+        super().__init__()
+        tmp_chn = z_channels*ch_mult[-1]
+        self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+                               resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+                               ch_mult=ch_mult, resolution=resolution, ch=ch)
+        self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+                                       out_channels=tmp_chn, depth=rescale_module_depth)
+
+    def forward(self, x):
+        x = self.rescaler(x)
+        x = self.decoder(x)
+        return x
+
+
+class Upsampler(nn.Module):
+    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+        super().__init__()
+        assert out_size >= in_size
+        num_blocks = int(np.log2(out_size//in_size))+1
+        factor_up = 1.+ (out_size % in_size)
+        print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+        self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+                                       out_channels=in_channels)
+        self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+                               attn_resolutions=[], in_channels=None, ch=in_channels,
+                               ch_mult=[ch_mult for _ in range(num_blocks)])
+
+    def forward(self, x):
+        x = self.rescaler(x)
+        x = self.decoder(x)
+        return x
+
+
+class Resize(nn.Module):
+    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+        super().__init__()
+        self.with_conv = learned
+        self.mode = mode
+        if self.with_conv:
+            print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode")
+            raise NotImplementedError()
+            assert in_channels is not None
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = torch.nn.Conv2d(in_channels,
+                                        in_channels,
+                                        kernel_size=4,
+                                        stride=2,
+                                        padding=1)
+
+    def forward(self, x, scale_factor=1.0):
+        if scale_factor==1.0:
+            return x
+        else:
+            x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+        return x
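+
+# A minimal usage sketch for LatentRescaler (illustrative only; the 4-channel
+# latent and the output shape are assumptions, following the interpolate call
+# in LatentRescaler.forward above):
+#     rescaler = LatentRescaler(factor=2.0, in_channels=4, mid_channels=64, out_channels=4)
+#     z_up = rescaler(torch.randn(1, 4, 32, 32))  # -> (1, 4, 64, 64)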
+
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution  # used by encode_with_pretrained below
+
+
+class FirstStagePostProcessor(nn.Module):
+
+    def __init__(self, ch_mult:list, in_channels,
+                 pretrained_model:nn.Module=None,
+                 reshape=False,
+                 n_channels=None,
+                 dropout=0.,
+                 pretrained_config=None):
+        super().__init__()
+        if pretrained_config is None:
+            assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+            self.pretrained_model = pretrained_model
+        else:
+            assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+            self.instantiate_pretrained(pretrained_config)
+
+        self.do_reshape = reshape
+
+        if n_channels is None:
+            n_channels = self.pretrained_model.encoder.ch
+
+        self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
+        self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
+                              stride=1,padding=1)
+
+        blocks = []
+        downs = []
+        ch_in = n_channels
+        for m in ch_mult:
+            blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
+            ch_in = m * n_channels
+            downs.append(Downsample(ch_in, with_conv=False))
+
+        self.model = nn.ModuleList(blocks)
+        self.downsampler = nn.ModuleList(downs)
+
+
+    def instantiate_pretrained(self, config):
+        model = instantiate_from_config(config)
+        self.pretrained_model = model.eval()
+        # self.pretrained_model.train = False
+        for param in self.pretrained_model.parameters():
+            param.requires_grad = False
+
+
+    @torch.no_grad()
+    def encode_with_pretrained(self,x):
+        c = self.pretrained_model.encode(x)
+        if isinstance(c, DiagonalGaussianDistribution):
+            c = c.mode()
+        return c
+
+    def forward(self,x):
+        z_fs = self.encode_with_pretrained(x)
+        z = self.proj_norm(z_fs)
+        z = self.proj(z)
+        z = nonlinearity(z)
+
+        for submodel, downmodel in zip(self.model,self.downsampler):
+            z = submodel(z,temb=None)
+            z = downmodel(z)
+
+        if self.do_reshape:
+            z = rearrange(z,'b c h w -> b (h w) c')
+        return z
+
diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py
new file mode 100644
index 00000000..fcf95d1e
--- /dev/null
+++ b/ldm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,961 @@
+from abc import abstractmethod
+from functools import partial
+import math
+from typing import Iterable
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ldm.modules.diffusionmodules.util import (
+    checkpoint,
+    conv_nd,
+    linear,
+    avg_pool_nd,
+    zero_module,
+    normalization,
+    timestep_embedding,
+)
+from ldm.modules.attention import SpatialTransformer
+
+
+# dummy replace
+def convert_module_to_f16(x):
+    pass
+
+def convert_module_to_f32(x):
+    pass
+
+
+## go
+class AttentionPool2d(nn.Module):
+    """
+    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+    """
+
+    def __init__(
+        self,
+        spacial_dim: int,
+        embed_dim: int,
+        num_heads_channels: int,
+        output_dim: int = None,
+    ):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+        self.num_heads = embed_dim // num_heads_channels
+        self.attention = QKVAttention(self.num_heads)
+
+    def forward(self, x):
+        b, c, *_spatial = x.shape
+        x = x.reshape(b, c, -1)  # NC(HW)
+        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)
+        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)
x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. 
+ :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 
+ from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + 
num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. 
+ """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + 
dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py new file mode 100644 index 00000000..a952e6c4 --- /dev/null +++ b/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,267 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! 
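# --- Editor's illustrative sketch; not part of the patch. ---
# Driving the UNetModel defined in openaimodel.py above: one denoising step on
# random latents. The constructor arguments below are assumptions chosen small
# enough for a CPU smoke test, and the import path assumes the module layout
# these patches create.
import torch as th
from ldm.modules.diffusionmodules.openaimodel import UNetModel

model = UNetModel(
    image_size=32, in_channels=4, model_channels=64, out_channels=4,
    num_res_blocks=1, attention_resolutions=(2,), channel_mult=(1, 2),
    num_heads=4,
)
x = th.randn(2, 4, 32, 32)       # [N x C x H x W] latent batch
t = th.randint(0, 1000, (2,))    # one timestep per batch element
eps = model(x, timesteps=t)      # prediction has the same shape as the input
assert eps.shape == x.shape
# --- end editor sketch ---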
+ + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. 
+ """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. 
+ """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/ldm/modules/distributions/__init__.py b/ldm/modules/distributions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py new file mode 100644 index 00000000..f2b8ef90 --- /dev/null +++ b/ldm/modules/distributions/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if 
self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py new file mode 100644 index 00000000..c8c75af4 --- /dev/null +++ b/ldm/modules/ema.py @@ -0,0 +1,76 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates + else torch.tensor(-1,dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + #remove as '.'-character is not allowed in buffers + s_name = name.replace('.','') + self.m_name2s_name.update({name:s_name}) + self.register_buffer(s_name,p.clone().detach().data) + + self.collected_params = [] + + def forward(self,model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. 
After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/ldm/modules/encoders/__init__.py b/ldm/modules/encoders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py new file mode 100644 index 00000000..ededbe43 --- /dev/null +++ b/ldm/modules/encoders/modules.py @@ -0,0 +1,234 @@ +import torch +import torch.nn as nn +from functools import partial +import clip +from einops import rearrange, repeat +from transformers import CLIPTokenizer, CLIPTextModel +import kornia + +from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class'): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + + def forward(self, batch, key=None): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + c = self.embedding(c) + return c + + +class TransformerEmbedder(AbstractEncoder): + """Some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + super().__init__() + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer)) + + def forward(self, tokens): + tokens = tokens.to(self.device) # meh + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, x): + return self(x) + + +class BERTTokenizer(AbstractEncoder): + """ Uses a pretrained BERT tokenizer by huggingface. 
Vocab size: 30522 (?)""" + def __init__(self, device="cuda", vq_interface=True, max_length=77): + super().__init__() + from transformers import BertTokenizerFast # TODO: add to reuquirements + self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + self.device = device + self.vq_interface = vq_interface + self.max_length = max_length + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + return tokens + + @torch.no_grad() + def encode(self, text): + tokens = self(text) + if not self.vq_interface: + return tokens + return None, None, [None, None, tokens] + + def decode(self, text): + return text + + +class BERTEmbedder(AbstractEncoder): + """Uses the BERT tokenizr model and add some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, + device="cuda",use_tokenizer=True, embedding_dropout=0.0): + super().__init__() + self.use_tknz_fn = use_tokenizer + if self.use_tknz_fn: + self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer), + emb_dropout=embedding_dropout) + + def forward(self, text): + if self.use_tknz_fn: + tokens = self.tknz_fn(text)#.to(self.device) + else: + tokens = text + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, text): + # output of length 77 + return self(text) + + +class SpatialRescaler(nn.Module): + def __init__(self, + n_stages=1, + method='bilinear', + multiplier=0.5, + in_channels=3, + out_channels=None, + bias=False): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.remap_output = out_channels is not None + if self.remap_output: + print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') + self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) + + def forward(self,x): + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + 
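# --- Editor's illustrative sketch; not part of the patch. ---
# Using the FrozenCLIPEmbedder defined above: it maps a batch of prompts to a
# [batch, max_length, hidden] tensor that the diffusion model consumes via
# cross-attention. Assumes the "openai/clip-vit-large-patch14" weights can be
# fetched from Hugging Face.
encoder = FrozenCLIPEmbedder(device="cpu")  # the class defaults to "cuda"
z = encoder.encode(["a photograph of an astronaut riding a horse"])
print(z.shape)  # torch.Size([1, 77, 768])
# --- end editor sketch ---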
+ +class FrozenCLIPTextEmbedder(nn.Module): + """ + Uses the CLIP transformer encoder for text. + """ + def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): + super().__init__() + self.model, _ = clip.load(version, jit=False, device="cpu") + self.device = device + self.max_length = max_length + self.n_repeat = n_repeat + self.normalize = normalize + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = clip.tokenize(text).to(self.device) + z = self.model.encode_text(tokens) + if self.normalize: + z = z / torch.linalg.norm(z, dim=1, keepdim=True) + return z + + def encode(self, text): + z = self(text) + if z.ndim==2: + z = z[:, None, :] + z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) + return z + + +class FrozenClipImageEmbedder(nn.Module): + """ + Uses the CLIP image encoder. + """ + def __init__( + self, + model, + jit=False, + device='cuda' if torch.cuda.is_available() else 'cpu', + antialias=False, + ): + super().__init__() + self.model, _ = clip.load(name=model, device=device, jit=jit) + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic',align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + # x is assumed to be in range [-1,1] + return self.model.encode_image(self.preprocess(x)) + + +if __name__ == "__main__": + from ldm.util import count_params + model = FrozenCLIPEmbedder() + count_params(model, verbose=True) \ No newline at end of file diff --git a/ldm/modules/encoders/xlmr.py b/ldm/modules/encoders/xlmr.py new file mode 100644 index 00000000..beab3fdf --- /dev/null +++ b/ldm/modules/encoders/xlmr.py @@ -0,0 +1,137 @@ +from transformers import BertPreTrainedModel,BertModel,BertConfig +import torch.nn as nn +import torch +from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig +from transformers import XLMRobertaModel,XLMRobertaTokenizer +from typing import Optional + +class BertSeriesConfig(BertConfig): + def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): + + super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + +class RobertaSeriesConfig(XLMRobertaConfig): + def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs): + 
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + + +class BertSeriesModelWithTransformation(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + config_class = BertSeriesConfig + + def __init__(self, config=None, **kargs): + # modify initialization for autoloading + if config is None: + config = XLMRobertaConfig() + config.attention_probs_dropout_prob= 0.1 + config.bos_token_id=0 + config.eos_token_id=2 + config.hidden_act='gelu' + config.hidden_dropout_prob=0.1 + config.hidden_size=1024 + config.initializer_range=0.02 + config.intermediate_size=4096 + config.layer_norm_eps=1e-05 + config.max_position_embeddings=514 + + config.num_attention_heads=16 + config.num_hidden_layers=24 + config.output_past=True + config.pad_token_id=1 + config.position_embedding_type= "absolute" + + config.type_vocab_size= 1 + config.use_cache=True + config.vocab_size= 250002 + config.project_dim = 768 + config.learn_encoder = False + super().__init__(config) + self.roberta = XLMRobertaModel(config) + self.transformation = nn.Linear(config.hidden_size,config.project_dim) + self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') + self.pooler = lambda x: x[:,0] + self.post_init() + + def encode(self,c): + device = next(self.parameters()).device + text = self.tokenizer(c, + truncation=True, + max_length=77, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt") + text["input_ids"] = torch.tensor(text["input_ids"]).to(device) + text["attention_mask"] = torch.tensor( + text['attention_mask']).to(device) + features = self(**text) + return features['projection_state'] + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) : + r""" + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + + # last module outputs + sequence_output = outputs[0] + + + # project every module + sequence_output_ln = self.pre_LN(sequence_output) + + # pooler + pooler_output = self.pooler(sequence_output_ln) + pooler_output = self.transformation(pooler_output) + projection_state = self.transformation(outputs.last_hidden_state) + + return { + 'pooler_output':pooler_output, + 'last_hidden_state':outputs.last_hidden_state, + 'hidden_states':outputs.hidden_states, + 'attentions':outputs.attentions, + 
'projection_state':projection_state, + 'sequence_out': sequence_output + } + + +class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): + base_model_prefix = 'roberta' + config_class= RobertaSeriesConfig \ No newline at end of file diff --git a/ldm/modules/image_degradation/__init__.py b/ldm/modules/image_degradation/__init__.py new file mode 100644 index 00000000..7836cada --- /dev/null +++ b/ldm/modules/image_degradation/__init__.py @@ -0,0 +1,2 @@ +from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr +from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light diff --git a/ldm/modules/image_degradation/bsrgan.py b/ldm/modules/image_degradation/bsrgan.py new file mode 100644 index 00000000..32ef5616 --- /dev/null +++ b/ldm/modules/image_degradation/bsrgan.py @@ -0,0 +1,730 @@ +# -*- coding: utf-8 -*- +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. 
+ Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = 
sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. 
+ threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. 
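# --- Editor's note; not part of the patch. The image is first quantized to
# 8-bit levels (the line above), then the lines below draw
# vals = 10 ** (2 * random.random() + 2.0), i.e. vals in [1e2, 1e4], which
# plays the role of a photon count: shot noise weakens as vals grows.
# Standalone illustration of the arithmetic (names are editor inventions):
#     flat = np.full((4, 4, 3), 0.5, dtype=np.float32)
#     vals_demo = 10 ** 3.0
#     noisy = np.random.poisson(flat * vals_demo).astype(np.float32) / vals_demo
#     # E[noisy] = 0.5, std = sqrt(0.5 / vals_demo) ~= 0.022
# --- end editor note ---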
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(30, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] 
# nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + elif i == 1: + image = add_blur(image, sf=sf) + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] 
# nearest downsampling + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image":image} + return example + + +# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... +def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): + """ + This is an extended degradation model by combining + the degradation models of BSRGAN and Real-ESRGAN + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + use_shuffle: the degradation shuffle + use_sharp: sharpening the img + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + if use_sharp: + img = add_sharpening(img) + hq = img.copy() + + if random.random() < shuffle_prob: + shuffle_order = random.sample(range(13), 13) + else: + shuffle_order = list(range(13)) + # local shuffle for noise, JPEG is always the last one + shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) + shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) + + poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 + + for i in shuffle_order: + if i == 0: + img = add_blur(img, sf=sf) + elif i == 1: + img = add_resize(img, sf=sf) + elif i == 2: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 3: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 4: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 5: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + elif i == 6: + img = add_JPEG_noise(img) + elif i == 7: + img = add_blur(img, sf=sf) + elif i == 8: + img = add_resize(img, sf=sf) + elif i == 9: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 10: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 11: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 12: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + else: + print('check the shuffle!') + + # resize to desired size + img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), + interpolation=random.choice([1, 2, 3])) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf, lq_patchsize) + 
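# --- Editor's illustrative sketch; not part of the patch. Exercising the
# degradation_bsrgan pipeline defined earlier in this file, under assumed
# inputs: it expects a float RGB image in [0, 1] whose sides are at least
# lq_patchsize * sf.
#     demo_hq = np.random.rand(512, 512, 3).astype(np.float32)
#     lq, hq_crop = degradation_bsrgan(demo_hq, sf=4, lq_patchsize=72)
#     # lq:      (72, 72, 3) degraded patch in [0, 1]
#     # hq_crop: (288, 288, 3) aligned clean patch
# --- end editor sketch ---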
+ return img, hq + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = util.uint2single(deg_fn(img)["image"]) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img)], axis=1) + util.imsave(img_concat, str(i) + '.png') + + diff --git a/ldm/modules/image_degradation/bsrgan_light.py b/ldm/modules/image_degradation/bsrgan_light.py new file mode 100644 index 00000000..9e1f8239 --- /dev/null +++ b/ldm/modules/image_degradation/bsrgan_light.py @@ -0,0 +1,650 @@ +# -*- coding: utf-8 -*- +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = 
+def fspecial_gaussian(hsize, sigma):
+    hsize = [hsize, hsize]
+    siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+    std = sigma
+    [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+    arg = -(x * x + y * y) / (2 * std * std)
+    h = np.exp(arg)
+    h[h < np.finfo(float).eps * h.max()] = 0
+    sumh = h.sum()
+    if sumh != 0:
+        h = h / sumh
+    return h
+
+
+def fspecial_laplacian(alpha):
+    alpha = max([0, min([alpha, 1])])
+    h1 = alpha / (alpha + 1)
+    h2 = (1 - alpha) / (alpha + 1)
+    h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+    h = np.array(h)
+    return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+    '''
+    python code from:
+    https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+    '''
+    if filter_type == 'gaussian':
+        return fspecial_gaussian(*args, **kwargs)
+    if filter_type == 'laplacian':
+        return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+    '''
+    Args:
+        x: HxWxC image, [0, 1]
+        sf: down-scale factor
+    Return:
+        bicubically downsampled LR image
+    '''
+    x = util.imresize_np(x, scale=1 / sf)
+    return x
+
+
+def srmd_degradation(x, k, sf=3):
+    ''' blur + bicubic downsampling
+    Args:
+        x: HxWxC image, [0, 1]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    Reference:
+        @inproceedings{zhang2018learning,
+          title={Learning a single convolutional super-resolution network for multiple degradations},
+          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+          pages={3262--3271},
+          year={2018}
+        }
+    '''
+    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
+    x = bicubic_degradation(x, sf=sf)
+    return x
+
+
+def dpsr_degradation(x, k, sf=3):
+    ''' bicubic downsampling + blur
+    Args:
+        x: HxWxC image, [0, 1]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    Reference:
+        @inproceedings{zhang2019deep,
+          title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+          author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+          booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+          pages={1671--1681},
+          year={2019}
+        }
+    '''
+    x = bicubic_degradation(x, sf=sf)
+    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+    return x
+
+
+def classical_degradation(x, k, sf=3):
+    ''' blur + downsampling
+    Args:
+        x: HxWxC image, [0, 1]/[0, 255]
+        k: hxw, double
+        sf: down-scale factor
+    Return:
+        downsampled LR image
+    '''
+    x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+    # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+    st = 0
+    return x[st::sf, st::sf, ...]
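A minimal sketch of the classical pipeline (assuming this module's context; shapes and sigma are illustrative): a classical x4 degradation is just a blur with an fspecial kernel followed by plain subsampling, so a 96x96 input comes back as 24x24.

    x = np.random.rand(96, 96, 3).astype(np.float32)  # toy HR image in [0, 1]
    k = fspecial('gaussian', 15, 2.0)                 # 15x15 Gaussian, sigma = 2.0
    y = classical_degradation(x, k, sf=4)
    print(y.shape)  # (24, 24, 3)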
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+    """USM sharpening. borrowed from real-ESRGAN
+    Input image: I; Blurry image: B.
+    1. K = I + weight * (I - B)
+    2. Mask = 1 if abs(I - B) > threshold, else: 0
+    3. Blur the mask to get a soft mask.
+    4. Out = Mask * K + (1 - Mask) * I
+    Args:
+        img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+        weight (float): Sharp weight. Default: 0.5.
+        radius (float): Kernel size of Gaussian blur. Default: 50.
+        threshold (int): Residual threshold (on the 0-255 scale) for the mask. Default: 10.
+    """
+    if radius % 2 == 0:
+        radius += 1
+    blur = cv2.GaussianBlur(img, (radius, radius), 0)
+    residual = img - blur
+    mask = np.abs(residual) * 255 > threshold
+    mask = mask.astype('float32')
+    soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+    K = img + weight * residual
+    K = np.clip(K, 0, 1)
+    return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+    wd2 = 4.0 + sf
+    wd = 2.0 + 0.2 * sf
+
+    wd2 = wd2/4
+    wd = wd/4
+
+    if random.random() < 0.5:
+        l1 = wd2 * random.random()
+        l2 = wd2 * random.random()
+        k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+    else:
+        k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
+    img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+    return img
+
+
+def add_resize(img, sf=4):
+    rnum = np.random.rand()
+    if rnum > 0.8:  # up
+        sf1 = random.uniform(1, 2)
+    elif rnum < 0.7:  # down
+        sf1 = random.uniform(0.5 / sf, 1)
+    else:
+        sf1 = 1.0
+    img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+    img = np.clip(img, 0.0, 1.0)
+
+    return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+#     noise_level = random.randint(noise_level1, noise_level2)
+#     rnum = np.random.rand()
+#     if rnum > 0.6:   # add color Gaussian noise
+#         img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+#     elif rnum < 0.4: # add grayscale Gaussian noise
+#         img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+#     else:            # add noise
+#         L = noise_level2 / 255.
+#         D = np.diag(np.random.rand(3))
+#         U = orth(np.random.rand(3, 3))
+#         conv = np.dot(np.dot(np.transpose(U), D), U)
+#         img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+#     img = np.clip(img, 0.0, 1.0)
+#     return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+    noise_level = random.randint(noise_level1, noise_level2)
+    rnum = np.random.rand()
+    if rnum > 0.6:   # add color Gaussian noise
+        img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+    elif rnum < 0.4: # add grayscale Gaussian noise
+        img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+    else:            # add noise
+        L = noise_level2 / 255.
+        D = np.diag(np.random.rand(3))
+        U = orth(np.random.rand(3, 3))
+        conv = np.dot(np.dot(np.transpose(U), D), U)
+        img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+    img = np.clip(img, 0.0, 1.0)
+    return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+    noise_level = random.randint(noise_level1, noise_level2)
+    img = np.clip(img, 0.0, 1.0)
+    rnum = random.random()
+    if rnum > 0.6:
+        img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+    elif rnum < 0.4:
+        img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+    else:
+        L = noise_level2 / 255.
+        D = np.diag(np.random.rand(3))
+        U = orth(np.random.rand(3, 3))
+        conv = np.dot(np.dot(np.transpose(U), D), U)
+        img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+    img = np.clip(img, 0.0, 1.0)
+    return img
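A minimal sketch of how these per-step primitives compose (parameter values are illustrative; the degradation pipelines below apply them in a random order):

    img = np.random.rand(64, 64, 3).astype(np.float32)  # toy image in [0, 1]
    img = add_blur(img, sf=4)
    img = add_resize(img, sf=4)
    img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
    assert 0.0 <= img.min() and img.max() <= 1.0  # resize/noise ops clip back to [0, 1]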
+def add_Poisson_noise(img):
+    img = np.clip((img * 255.0).round(), 0, 255) / 255.
+    vals = 10 ** (2 * random.random() + 2.0)  # exponent sampled from [2, 4]
+    if random.random() < 0.5:
+        img = np.random.poisson(img * vals).astype(np.float32) / vals
+    else:
+        img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+        img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+        noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+        img += noise_gray[:, :, np.newaxis]
+    img = np.clip(img, 0.0, 1.0)
+    return img
+
+
+def add_JPEG_noise(img):
+    quality_factor = random.randint(80, 95)
+    img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+    result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+    img = cv2.imdecode(encimg, 1)
+    img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+    return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+    h, w = lq.shape[:2]
+    rnd_h = random.randint(0, h - lq_patchsize)
+    rnd_w = random.randint(0, w - lq_patchsize)
+    lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+    rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+    hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+    return lq, hq
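A minimal sketch (shapes are illustrative): random_crop keeps the LQ/HQ pair aligned by scaling the crop origin by sf.

    lq = np.random.rand(64, 64, 3)
    hq = np.random.rand(256, 256, 3)  # 4x the LQ resolution
    lq_p, hq_p = random_crop(lq, hq, sf=4, lq_patchsize=32)
    print(lq_p.shape, hq_p.shape)  # (32, 32, 3) (128, 128, 3)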
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+    """
+    This is the degradation model of BSRGAN from the paper
+    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+    ----------
+    img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
+    sf: scale factor
+    isp_model: camera ISP model
+    Returns
+    -------
+    img: low-quality patch, size: lq_patchsize X lq_patchsize X C, range: [0, 1]
+    hq: corresponding high-quality patch, size: (lq_patchsize x sf) X (lq_patchsize x sf) X C, range: [0, 1]
+    """
+    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+    sf_ori = sf
+
+    h1, w1 = img.shape[:2]
+    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop (height first, then width)
+    h, w = img.shape[:2]
+
+    if h < lq_patchsize * sf or w < lq_patchsize * sf:
+        raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+    hq = img.copy()
+
+    if sf == 4 and random.random() < scale2_prob:  # downsample1
+        if np.random.rand() < 0.5:
+            img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+                             interpolation=random.choice([1, 2, 3]))
+        else:
+            img = util.imresize_np(img, 1 / 2, True)
+        img = np.clip(img, 0.0, 1.0)
+        sf = 2
+
+    shuffle_order = random.sample(range(7), 7)
+    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+    if idx1 > idx2:  # keep downsample3 last
+        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+    for i in shuffle_order:
+
+        if i == 0:
+            img = add_blur(img, sf=sf)
+
+        elif i == 1:
+            img = add_blur(img, sf=sf)
+
+        elif i == 2:
+            a, b = img.shape[1], img.shape[0]
+            # downsample2
+            if random.random() < 0.75:
+                sf1 = random.uniform(1, 2 * sf)
+                img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+                                 interpolation=random.choice([1, 2, 3]))
+            else:
+                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+                k_shifted = shift_pixel(k, sf)
+                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
+                img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+                img = img[0::sf, 0::sf, ...]  # nearest downsampling
+            img = np.clip(img, 0.0, 1.0)
+
+        elif i == 3:
+            # downsample3
+            img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+            img = np.clip(img, 0.0, 1.0)
+
+        elif i == 4:
+            # add Gaussian noise
+            img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
+
+        elif i == 5:
+            # add JPEG noise
+            if random.random() < jpeg_prob:
+                img = add_JPEG_noise(img)
+
+        elif i == 6:
+            # add processed camera sensor noise
+            if random.random() < isp_prob and isp_model is not None:
+                with torch.no_grad():
+                    img, hq = isp_model.forward(img.copy(), hq)
+
+    # add final JPEG compression noise
+    img = add_JPEG_noise(img)
+
+    # random crop
+    img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+    return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+    """
+    This is the degradation model of BSRGAN from the paper
+    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+    ----------
+    image: HxWxC uint8 image
+    sf: scale factor
+    isp_model: camera ISP model (unused here; see the commented-out branch below)
+    Returns
+    -------
+    example: dict holding the degraded low-quality image (uint8) under the key "image"
+    """
+    image = util.uint2single(image)
+    isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+    sf_ori = sf
+
+    h1, w1 = image.shape[:2]
+    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop (height first, then width)
+    h, w = image.shape[:2]
+
+    hq = image.copy()
+
+    if sf == 4 and random.random() < scale2_prob:  # downsample1
+        if np.random.rand() < 0.5:
+            image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+                               interpolation=random.choice([1, 2, 3]))
+        else:
+            image = util.imresize_np(image, 1 / 2, True)
+        image = np.clip(image, 0.0, 1.0)
+        sf = 2
+
+    shuffle_order = random.sample(range(7), 7)
+    idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+    if idx1 > idx2:  # keep downsample3 last
+        shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+    for i in shuffle_order:
+
+        if i == 0:
+            image = add_blur(image, sf=sf)
+
+        # elif i == 1:
+        #     image = add_blur(image, sf=sf)
+
+        elif i == 2:
+            a, b = image.shape[1], image.shape[0]
+            # downsample2
+            if random.random() < 0.8:
+                sf1 = random.uniform(1, 2 * sf)
+                image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+                                   interpolation=random.choice([1, 2, 3]))
+            else:
+                k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+                k_shifted = shift_pixel(k, sf)
+                k_shifted = k_shifted / k_shifted.sum()  # blur with shifted kernel
+                image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+                image = image[0::sf, 0::sf, ...]
# nearest downsampling + + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + # + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image": image} + return example + + + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_hq = img + img_lq = deg_fn(img)["image"] + img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') diff --git a/ldm/modules/image_degradation/utils/test.png b/ldm/modules/image_degradation/utils/test.png new file mode 100644 index 00000000..4249b43d Binary files /dev/null and b/ldm/modules/image_degradation/utils/test.png differ diff --git a/ldm/modules/image_degradation/utils_image.py b/ldm/modules/image_degradation/utils_image.py new file mode 100644 index 00000000..0175f155 --- /dev/null +++ b/ldm/modules/image_degradation/utils_image.py @@ -0,0 +1,916 @@ +import os +import math +import random +import numpy as np +import torch +import cv2 +from torchvision.utils import make_grid +from datetime import datetime +#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py + + +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/twhui/SRGAN-pyTorch +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def imshow(x, title=None, cbar=False, figsize=None): + plt.figure(figsize=figsize) + plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + if title: + plt.title(title) + if cbar: + plt.colorbar() + plt.show() + + +def surf(Z, cmap='rainbow', figsize=None): + plt.figure(figsize=figsize) + ax3 = plt.axes(projection='3d') + + w, h = Z.shape[:2] + xx = np.arange(0,w,1) + 
yy = np.arange(0, h, 1)
+    X, Y = np.meshgrid(xx, yy)
+    ax3.plot_surface(X, Y, Z, cmap=cmap)
+    #ax3.contour(X, Y, Z, zdim='z', offset=-2, cmap=cmap)
+    plt.show()
+
+
+'''
+# --------------------------------------------
+# get image paths
+# --------------------------------------------
+'''
+
+
+def get_image_paths(dataroot):
+    paths = None  # return None if dataroot is None
+    if dataroot is not None:
+        paths = sorted(_get_paths_from_images(dataroot))
+    return paths
+
+
+def _get_paths_from_images(path):
+    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+    images = []
+    for dirpath, _, fnames in sorted(os.walk(path)):
+        for fname in sorted(fnames):
+            if is_image_file(fname):
+                img_path = os.path.join(dirpath, fname)
+                images.append(img_path)
+    assert images, '{:s} has no valid image file'.format(path)
+    return images
+
+
+'''
+# --------------------------------------------
+# split large images into small images
+# --------------------------------------------
+'''
+
+
+def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
+    w, h = img.shape[:2]
+    patches = []
+    if w > p_max and h > p_max:
+        w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=int))
+        h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=int))
+        w1.append(w-p_size)
+        h1.append(h-p_size)
+#        print(w1)
+#        print(h1)
+        for i in w1:
+            for j in h1:
+                patches.append(img[i:i+p_size, j:j+p_size, :])
+    else:
+        patches.append(img)
+
+    return patches
+
+
+def imssave(imgs, img_path):
+    """
+    imgs: list, N images of size WxHxC
+    """
+    img_name, ext = os.path.splitext(os.path.basename(img_path))
+
+    for i, img in enumerate(imgs):
+        if img.ndim == 3:
+            img = img[:, :, [2, 1, 0]]
+        new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
+        cv2.imwrite(new_path, img)
+
+
+def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
+    """
+    split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
+    and save them into taget_dataroot; only images larger than (p_max)x(p_max)
+    will be split.
+    Args:
+        original_dataroot:
+        taget_dataroot:
+        p_size: size of the small images
+        p_overlap: overlap between neighbouring patches; the patch size used in training is a good choice
+        p_max: images smaller than (p_max)x(p_max) are kept unchanged.
+    """
+    paths = get_image_paths(original_dataroot)
+    for img_path in paths:
+        # img_name, ext = os.path.splitext(os.path.basename(img_path))
+        img = imread_uint(img_path, n_channels=n_channels)
+        patches = patches_from_image(img, p_size, p_overlap, p_max)
+        imssave(patches, os.path.join(taget_dataroot, os.path.basename(img_path)))
+        #if original_dataroot == taget_dataroot:
+        #del img_path
+
+'''
+# --------------------------------------------
+# makedir
+# --------------------------------------------
+'''
+
+
+def mkdir(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def mkdirs(paths):
+    if isinstance(paths, str):
+        mkdir(paths)
+    else:
+        for path in paths:
+            mkdir(path)
+
+
+def mkdir_and_rename(path):
+    if os.path.exists(path):
+        new_name = path + '_archived_' + get_timestamp()
+        print('Path already exists. Rename it to [{:s}]'.format(new_name))
+        os.rename(path, new_name)
+    os.makedirs(path)
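A minimal usage sketch (the directory names are hypothetical): splitting a folder of large images into overlapping 800x800 patches for training.

    split_imageset('trainsets/DIV2K_HR', 'trainsets/DIV2K_patches',
                   n_channels=3, p_size=800, p_overlap=96, p_max=1000)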
+'''
+# --------------------------------------------
+# read image from path
+# opencv is fast, but reads BGR numpy images
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# get uint8 image of size HxWxn_channels (RGB)
+# --------------------------------------------
+def imread_uint(path, n_channels=3):
+    # input: path
+    # output: HxWx3 (RGB or GGG), or HxWx1 (G)
+    if n_channels == 1:
+        img = cv2.imread(path, 0)  # cv2.IMREAD_GRAYSCALE
+        img = np.expand_dims(img, axis=2)  # HxWx1
+    elif n_channels == 3:
+        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR or G
+        if img.ndim == 2:
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)  # GGG
+        else:
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # RGB
+    return img
+
+
+# --------------------------------------------
+# matlab's imwrite
+# --------------------------------------------
+def imsave(img, img_path):
+    img = np.squeeze(img)
+    if img.ndim == 3:
+        img = img[:, :, [2, 1, 0]]
+    cv2.imwrite(img_path, img)
+
+def imwrite(img, img_path):
+    img = np.squeeze(img)
+    if img.ndim == 3:
+        img = img[:, :, [2, 1, 0]]
+    cv2.imwrite(img_path, img)
+
+
+# --------------------------------------------
+# get single image of size HxWxn_channels (BGR)
+# --------------------------------------------
+def read_img(path):
+    # read image by cv2
+    # return: Numpy float32, HWC, BGR, [0,1]
+    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # cv2.IMREAD_GRAYSCALE
+    img = img.astype(np.float32) / 255.
+    if img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    # some images have 4 channels
+    if img.shape[2] > 3:
+        img = img[:, :, :3]
+    return img
+
+
+'''
+# --------------------------------------------
+# image format conversion
+# --------------------------------------------
+# numpy(single) <--->  numpy(uint)
+# numpy(single) <--->  tensor
+# numpy(uint)   <--->  tensor
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# numpy(single) [0, 1] <--->  numpy(uint)
+# --------------------------------------------
+
+
+def uint2single(img):
+
+    return np.float32(img/255.)
+
+
+def single2uint(img):
+
+    return np.uint8((img.clip(0, 1)*255.).round())
+
+
+def uint162single(img):
+
+    return np.float32(img/65535.)
+
+
+def single2uint16(img):
+
+    return np.uint16((img.clip(0, 1)*65535.).round())
+
+
+# --------------------------------------------
+# numpy(uint) (HxWxC or HxW) <--->  tensor
+# --------------------------------------------
+
+
+# convert uint to 4-dimensional torch tensor
+def uint2tensor4(img):
+    if img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
+
+
+# convert uint to 3-dimensional torch tensor
+def uint2tensor3(img):
+    if img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
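A minimal sketch (assuming numpy and torch from this module's imports): the uint/single/tensor converters round-trip losslessly and produce the 1xCxHxW layout the models expect.

    img = (np.random.rand(32, 32, 3) * 255).astype(np.uint8)
    t = uint2tensor4(img)                 # 1x3x32x32 float tensor in [0, 1]
    back = single2uint(uint2single(img))  # uint8 -> float32 -> uint8 round trip
    assert t.shape == (1, 3, 32, 32) and (back == img).all()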
+
+
+# convert 2/3/4-dimensional torch tensor to uint
+def tensor2uint(img):
+    img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
+    if img.ndim == 3:
+        img = np.transpose(img, (1, 2, 0))
+    return np.uint8((img*255.0).round())
+
+
+# --------------------------------------------
+# numpy(single) (HxWxC) <--->  tensor
+# --------------------------------------------
+
+
+# convert single (HxWxC) to 3-dimensional torch tensor
+def single2tensor3(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
+
+
+# convert single (HxWxC) to 4-dimensional torch tensor
+def single2tensor4(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
+
+
+# convert torch tensor to single
+def tensor2single(img):
+    img = img.data.squeeze().float().cpu().numpy()
+    if img.ndim == 3:
+        img = np.transpose(img, (1, 2, 0))
+
+    return img
+
+# convert torch tensor to single
+def tensor2single3(img):
+    img = img.data.squeeze().float().cpu().numpy()
+    if img.ndim == 3:
+        img = np.transpose(img, (1, 2, 0))
+    elif img.ndim == 2:
+        img = np.expand_dims(img, axis=2)
+    return img
+
+
+def single2tensor5(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
+
+
+def single32tensor5(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
+
+
+def single42tensor4(img):
+    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
+
+
+# from skimage.io import imread, imsave
+def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
+    '''
+    Converts a torch Tensor into an image Numpy array of BGR channel order
+    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
+    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
+    '''
+    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # squeeze first, then clamp
+    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
+    n_dim = tensor.dim()
+    if n_dim == 4:
+        n_img = len(tensor)
+        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
+        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
+    elif n_dim == 3:
+        img_np = tensor.numpy()
+        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
+    elif n_dim == 2:
+        img_np = tensor.numpy()
+    else:
+        raise TypeError(
+            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
+    if out_type == np.uint8:
+        img_np = (img_np * 255.0).round()
+        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
+    return img_np.astype(out_type)
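A minimal sketch: tensor2img squeezes away the batch dimension and returns an HxWxC uint8 array in BGR order, ready for cv2.imwrite.

    t = torch.rand(1, 3, 16, 16)  # RGB tensor in [0, 1]
    img = tensor2img(t)
    print(img.shape, img.dtype)   # (16, 16, 3) uint8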
+
+
+'''
+# --------------------------------------------
+# Augmentation, flip and/or rotate
+# --------------------------------------------
+# The following two are enough.
+# (1) augment_img: numpy image of WxHxC or WxH
+# (2) augment_img_tensor4: tensor image 1xCxWxH
+# --------------------------------------------
+'''
+
+
+def augment_img(img, mode=0):
+    '''Kai Zhang (github: https://github.com/cszn)
+    '''
+    if mode == 0:
+        return img
+    elif mode == 1:
+        return np.flipud(np.rot90(img))
+    elif mode == 2:
+        return np.flipud(img)
+    elif mode == 3:
+        return np.rot90(img, k=3)
+    elif mode == 4:
+        return np.flipud(np.rot90(img, k=2))
+    elif mode == 5:
+        return np.rot90(img)
+    elif mode == 6:
+        return np.rot90(img, k=2)
+    elif mode == 7:
+        return np.flipud(np.rot90(img, k=3))
+
+
+def augment_img_tensor4(img, mode=0):
+    '''Kai Zhang (github: https://github.com/cszn)
+    '''
+    if mode == 0:
+        return img
+    elif mode == 1:
+        return img.rot90(1, [2, 3]).flip([2])
+    elif mode == 2:
+        return img.flip([2])
+    elif mode == 3:
+        return img.rot90(3, [2, 3])
+    elif mode == 4:
+        return img.rot90(2, [2, 3]).flip([2])
+    elif mode == 5:
+        return img.rot90(1, [2, 3])
+    elif mode == 6:
+        return img.rot90(2, [2, 3])
+    elif mode == 7:
+        return img.rot90(3, [2, 3]).flip([2])
+
+
+def augment_img_tensor(img, mode=0):
+    '''Kai Zhang (github: https://github.com/cszn)
+    '''
+    img_size = img.size()
+    img_np = img.data.cpu().numpy()
+    if len(img_size) == 3:
+        img_np = np.transpose(img_np, (1, 2, 0))
+    elif len(img_size) == 4:
+        img_np = np.transpose(img_np, (2, 3, 1, 0))
+    img_np = augment_img(img_np, mode=mode)
+    img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
+    if len(img_size) == 3:
+        img_tensor = img_tensor.permute(2, 0, 1)
+    elif len(img_size) == 4:
+        img_tensor = img_tensor.permute(3, 2, 0, 1)
+
+    return img_tensor.type_as(img)
+
+
+def augment_img_np3(img, mode=0):
+    if mode == 0:
+        return img
+    elif mode == 1:
+        return img.transpose(1, 0, 2)
+    elif mode == 2:
+        return img[::-1, :, :]
+    elif mode == 3:
+        img = img[::-1, :, :]
+        img = img.transpose(1, 0, 2)
+        return img
+    elif mode == 4:
+        return img[:, ::-1, :]
+    elif mode == 5:
+        img = img[:, ::-1, :]
+        img = img.transpose(1, 0, 2)
+        return img
+    elif mode == 6:
+        img = img[:, ::-1, :]
+        img = img[::-1, :, :]
+        return img
+    elif mode == 7:
+        img = img[:, ::-1, :]
+        img = img[::-1, :, :]
+        img = img.transpose(1, 0, 2)
+        return img
+
+
+def augment_imgs(img_list, hflip=True, rot=True):
+    # horizontal flip OR rotate
+    hflip = hflip and random.random() < 0.5
+    vflip = rot and random.random() < 0.5
+    rot90 = rot and random.random() < 0.5
+
+    def _augment(img):
+        if hflip:
+            img = img[:, ::-1, :]
+        if vflip:
+            img = img[::-1, :, :]
+        if rot90:
+            img = img.transpose(1, 0, 2)
+        return img
+
+    return [_augment(img) for img in img_list]
+
+
+'''
+# --------------------------------------------
+# modcrop and shave
+# --------------------------------------------
+'''
+
+
+def modcrop(img_in, scale):
+    # img_in: Numpy, HWC or HW
+    img = np.copy(img_in)
+    if img.ndim == 2:
+        H, W = img.shape
+        H_r, W_r = H % scale, W % scale
+        img = img[:H - H_r, :W - W_r]
+    elif img.ndim == 3:
+        H, W, C = img.shape
+        H_r, W_r = H % scale, W % scale
+        img = img[:H - H_r, :W - W_r, :]
+    else:
+        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
+    return img
+
+
+def shave(img_in, border=0):
+    # img_in: Numpy, HWC or HW
+    img = np.copy(img_in)
+    h, w = img.shape[:2]
+    img = img[border:h-border, border:w-border]
+    return img
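A minimal sketch: modes 0-7 of augment_img enumerate the eight flip/rotate (dihedral) transforms commonly used for data or test-time augmentation.

    img = np.random.rand(8, 8, 3)
    augmented = [augment_img(img, mode=m) for m in range(8)]
    assert all(a.shape == img.shape for a in augmented)  # a square input keeps its shape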
+
+
+'''
+# --------------------------------------------
+# image processing process on numpy image
+# channel_convert(in_c, tar_type, img_list):
+# rgb2ycbcr(img, only_y=True):
+# bgr2ycbcr(img, only_y=True):
+# ycbcr2rgb(img):
+# --------------------------------------------
+'''
+
+
+def rgb2ycbcr(img, only_y=True):
+    '''same as matlab rgb2ycbcr
+    only_y: only return Y channel
+    Input:
+        uint8, [0, 255]
+        float, [0, 1]
+    '''
+    in_img_type = img.dtype
+    img = img.astype(np.float32)
+    if in_img_type != np.uint8:
+        img *= 255.
+    # convert
+    if only_y:
+        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+    else:
+        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
+    if in_img_type == np.uint8:
+        rlt = rlt.round()
+    else:
+        rlt /= 255.
+    return rlt.astype(in_img_type)
+
+
+def ycbcr2rgb(img):
+    '''same as matlab ycbcr2rgb
+    Input:
+        uint8, [0, 255]
+        float, [0, 1]
+    '''
+    in_img_type = img.dtype
+    img = img.astype(np.float32)
+    if in_img_type != np.uint8:
+        img *= 255.
+    # convert
+    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
+    if in_img_type == np.uint8:
+        rlt = rlt.round()
+    else:
+        rlt /= 255.
+    return rlt.astype(in_img_type)
+
+
+def bgr2ycbcr(img, only_y=True):
+    '''bgr version of rgb2ycbcr
+    only_y: only return Y channel
+    Input:
+        uint8, [0, 255]
+        float, [0, 1]
+    '''
+    in_img_type = img.dtype
+    img = img.astype(np.float32)
+    if in_img_type != np.uint8:
+        img *= 255.
+    # convert
+    if only_y:
+        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+    else:
+        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+    if in_img_type == np.uint8:
+        rlt = rlt.round()
+    else:
+        rlt /= 255.
+    return rlt.astype(in_img_type)
+
+
+def channel_convert(in_c, tar_type, img_list):
+    # conversion among BGR, gray and y
+    if in_c == 3 and tar_type == 'gray':  # BGR to gray
+        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
+        return [np.expand_dims(img, axis=2) for img in gray_list]
+    elif in_c == 3 and tar_type == 'y':  # BGR to y
+        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
+        return [np.expand_dims(img, axis=2) for img in y_list]
+    elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
+        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
+    else:
+        return img_list
+
+
+'''
+# --------------------------------------------
+# metric, PSNR and SSIM
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# PSNR
+# --------------------------------------------
+def calculate_psnr(img1, img2, border=0):
+    # img1 and img2 have range [0, 255]
+    #img1 = img1.squeeze()
+    #img2 = img2.squeeze()
+    if not img1.shape == img2.shape:
+        raise ValueError('Input images must have the same dimensions.')
+    h, w = img1.shape[:2]
+    img1 = img1[border:h-border, border:w-border]
+    img2 = img2[border:h-border, border:w-border]
+
+    img1 = img1.astype(np.float64)
+    img2 = img2.astype(np.float64)
+    mse = np.mean((img1 - img2)**2)
+    if mse == 0:
+        return float('inf')
+    return 20 * math.log10(255.0 / math.sqrt(mse))
+
+
+# --------------------------------------------
+# SSIM
+# --------------------------------------------
+def calculate_ssim(img1, img2, border=0):
+    '''calculate SSIM
+    the same outputs as MATLAB's
+    img1, img2: [0, 255]
+    '''
+    #img1 = img1.squeeze()
+    #img2 = img2.squeeze()
+    if not img1.shape == img2.shape:
+        raise ValueError('Input images must have the same dimensions.')
+    h, w = img1.shape[:2]
+    img1 = img1[border:h-border, border:w-border]
+    img2 =
img2[border:h-border, border:w-border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +''' +# -------------------------------------------- +# matlab's bicubic imresize (numpy and torch) [0, 1] +# -------------------------------------------- +''' + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ + (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. 
+ weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +# -------------------------------------------- +# imresize for tensor image [0, 1] +# -------------------------------------------- +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: pytorch tensor, CHW or HW [0,1] + # output: CHW or HW [0,1] w/o round + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(0) + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + return out_2 + + +# -------------------------------------------- +# imresize for numpy image [0, 1] +# -------------------------------------------- +def 
imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +if __name__ == '__main__': + print('---') +# img = imread_uint('test.bmp', 3) +# img = uint2single(img) +# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file diff --git a/ldm/modules/losses/__init__.py b/ldm/modules/losses/__init__.py new file mode 100644 index 00000000..876d7c5b --- /dev/null +++ b/ldm/modules/losses/__init__.py @@ -0,0 +1 @@ +from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator \ No newline at end of file diff --git a/ldm/modules/losses/contperceptual.py b/ldm/modules/losses/contperceptual.py new file mode 100644 index 00000000..672c1e32 --- /dev/null +++ b/ldm/modules/losses/contperceptual.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn + +from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
+ + +class LPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_loss="hinge"): + + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + # output log variance + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm + ).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, inputs, reconstructions, posteriors, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", + weights=None): + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights*nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): 
torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log + diff --git a/ldm/modules/losses/vqperceptual.py b/ldm/modules/losses/vqperceptual.py new file mode 100644 index 00000000..f6998176 --- /dev/null +++ b/ldm/modules/losses/vqperceptual.py @@ -0,0 +1,167 @@ +import torch +from torch import nn +import torch.nn.functional as F +from einops import repeat + +from taming.modules.discriminator.model import NLayerDiscriminator, weights_init +from taming.modules.losses.lpips import LPIPS +from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss + + +def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): + assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] + loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) + loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) + loss_real = (weights * loss_real).sum() / weights.sum() + loss_fake = (weights * loss_fake).sum() / weights.sum() + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + +def adopt_weight(weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + +def measure_perplexity(predicted_indices, n_embed): + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally
+    encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
+    avg_probs = encodings.mean(0)
+    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
+    cluster_use = torch.sum(avg_probs > 0)
+    return perplexity, cluster_use
+
+
+def exists(val):
+    # used by VQLPIPSWithDiscriminator.forward below; defined here because it is
+    # neither imported nor defined anywhere else in this file
+    return val is not None
+
+
+def l1(x, y):
+    return torch.abs(x-y)
+
+
+def l2(x, y):
+    return torch.pow((x-y), 2)
+
+
+class VQLPIPSWithDiscriminator(nn.Module):
+    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
+                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+                 disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
+                 pixel_loss="l1"):
+        super().__init__()
+        assert disc_loss in ["hinge", "vanilla"]
+        assert perceptual_loss in ["lpips", "clips", "dists"]
+        assert pixel_loss in ["l1", "l2"]
+        self.codebook_weight = codebook_weight
+        self.pixel_weight = pixelloss_weight
+        if perceptual_loss == "lpips":
+            print(f"{self.__class__.__name__}: Running with LPIPS.")
+            self.perceptual_loss = LPIPS().eval()
+        else:
+            raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
+        self.perceptual_weight = perceptual_weight
+
+        if pixel_loss == "l1":
+            self.pixel_loss = l1
+        else:
+            self.pixel_loss = l2
+
+        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+                                                 n_layers=disc_num_layers,
+                                                 use_actnorm=use_actnorm,
+                                                 ndf=disc_ndf
+                                                 ).apply(weights_init)
+        self.discriminator_iter_start = disc_start
+        if disc_loss == "hinge":
+            self.disc_loss = hinge_d_loss
+        elif disc_loss == "vanilla":
+            self.disc_loss = vanilla_d_loss
+        else:
+            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
+        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
+        self.disc_factor = disc_factor
+        self.discriminator_weight = disc_weight
+        self.disc_conditional = disc_conditional
+        self.n_classes = n_classes
+
+    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+        if last_layer is not None:
+            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+        else:
+            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+        d_weight = d_weight * self.discriminator_weight
+        return d_weight
+
+    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
+                global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
+        if not exists(codebook_loss):
+            codebook_loss = torch.tensor([0.]).to(inputs.device)
+        #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+        rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
+        if self.perceptual_weight > 0:
+            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+            rec_loss = rec_loss + self.perceptual_weight * p_loss
+        else:
+            p_loss = torch.tensor([0.0])
+
+        nll_loss = rec_loss
+        #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+        nll_loss = torch.mean(nll_loss)
+
+        # now the GAN part
+        if optimizer_idx == 0:
+            # generator update
+            if cond is None:
+                assert not self.disc_conditional
+                logits_fake = self.discriminator(reconstructions.contiguous())
+            else:
+                assert self.disc_conditional
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/quant_loss".format(split): codebook_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/p_loss".format(split): p_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + if predicted_indices is not None: + assert self.n_classes is not None + with torch.no_grad(): + perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) + log[f"{split}/perplexity"] = perplexity + log[f"{split}/cluster_usage"] = cluster_usage + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log diff --git a/ldm/modules/x_transformer.py b/ldm/modules/x_transformer.py new file mode 100644 index 00000000..5fc15bf9 --- /dev/null +++ b/ldm/modules/x_transformer.py @@ -0,0 +1,641 @@ +"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial +from inspect import isfunction +from collections import namedtuple +from einops import rearrange, repeat, reduce + +# constants + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', [ + 'pre_softmax_attn', + 'post_softmax_attn' +]) + +LayerIntermediates = namedtuple('Intermediates', [ + 'hiddens', + 'attn_intermediates' +]) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. 
/ (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + return inner + + +def not_equals(val): + def inner(x): + return x != val + return inner + + +def equals(val): + def inner(x): + return x == val + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru( + rearrange(x, 'b n d -> (b n) d'), + rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +# feedforward + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + 
project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +# attention. +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False + ): + super().__init__() + if use_entmax15: + raise NotImplementedError("Check out entmax activation instead of softmax activation!") + self.scale = dim_head ** -0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + #self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None + ): + b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots = 
einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.rotary_pos_emb = always(None) + + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - 
len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([ + norm_fn(), + layer, + residual_fn + ])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False + ): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = ind == (len(self.layers) - 1) + + if layer_type == 'a': + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == 'a': + out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, + prev_attn=prev_attn, mem=layer_mem) + elif layer_type == 'c': + out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + elif layer_type == 'f': + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, + attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits 
= nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, 'num_memory_tokens'): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens + new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + return out, new_mems + + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + return out, attn_maps + + return out + diff --git a/ldm/util.py b/ldm/util.py new file mode 100644 index 00000000..8ba38853 --- /dev/null +++ b/ldm/util.py @@ -0,0 +1,203 @@ +import importlib + +import torch +import numpy as np +from collections import abc +from einops import rearrange +from functools import partial + +import multiprocessing as mp +from threading import Thread +from queue import Queue + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. 
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False +): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." + ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc)) + ] + else: + step = ( + int(len(data) / n_proc + 1) + if len(data) % n_proc != 0 + else int(len(data) / n_proc) + ) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i: i + step] for i in range(0, len(data), step)] + ) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. 
[{time.time() - start} sec.]") + + if target_data_type == 'ndarray': + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == 'list': + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res diff --git a/modules/devices.py b/modules/devices.py index 67165bf6..f30b6ebc 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -36,8 +36,8 @@ def get_optimal_device(): else: return torch.device("cuda") - if has_mps(): - return torch.device("mps") + # if has_mps(): + # return torch.device("mps") return cpu diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index eaedac13..26280fe4 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -70,14 +70,19 @@ class StableDiffusionModelHijack: embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) def hijack(self, m): - model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + + if shared.text_model_name == "XLMR-Large": + model_embeddings = m.cond_stage_model.roberta.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) + else : + model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embeddings, self) - model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) self.clip = m.cond_stage_model - apply_optimizations() + # apply_optimizations() def flatten(el): flattened = [flatten(children) for children in el.children()] @@ -125,8 +130,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.tokenizer = wrapped.tokenizer self.token_mults = {} - self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] - + try: + self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] + except: + self.comma_token = None + tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] for text, ident in tokens_with_parens: mult = 1.0 @@ -298,6 +306,9 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count def forward(self, text): + if shared.text_model_name == "XLMR-Large": + return self.wrapped.encode(text) + use_old = opts.use_old_emphasis_implementation if use_old: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) @@ -359,7 +370,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z = self.wrapped.transformer.text_model.final_layer_norm(z) else: z = outputs.last_hidden_state - + # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers] batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device) diff --git a/modules/shared.py b/modules/shared.py index c93ae2a3..9941d2f4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -21,7 +21,7 @@ from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 
'model.ckpt') default_sd_model_file = sd_model_file parser = argparse.ArgumentParser() -parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",) +parser.add_argument("--config", type=str, default="configs/altdiffusion/ad-inference.yaml", help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) @@ -106,6 +106,10 @@ restricted_opts = { "outdir_txt2img_grids", "outdir_save", } +from omegaconf import OmegaConf +config = OmegaConf.load(f"{cmd_opts.config}") +# XLMR-Large +text_model_name = config.model.params.cond_stage_config.params.name cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access -- cgit v1.2.3 From ee3f5ea3eeb31f1ed72e2f0cbed2c00a782497d8 Mon Sep 17 00:00:00 2001 From: zhaohu xing <920232796@qq.com> Date: Tue, 29 Nov 2022 10:30:19 +0800 Subject: delete old config file Signed-off-by: zhaohu xing <920232796@qq.com> --- configs/stable-diffusion/v1-inference.yaml | 71 ------------------------------ 1 file changed, 71 deletions(-) delete mode 100644 configs/stable-diffusion/v1-inference.yaml diff --git a/configs/stable-diffusion/v1-inference.yaml b/configs/stable-diffusion/v1-inference.yaml deleted file mode 100644 index 2e6ef0f2..00000000 --- a/configs/stable-diffusion/v1-inference.yaml +++ /dev/null @@ -1,71 +0,0 @@ -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. 
] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - # target: ldm.modules.encoders.modules.FrozenCLIPEmbedder - target: altclip.model.AltCLIPEmbedder \ No newline at end of file -- cgit v1.2.3 From 52cc83d36b7663a77b79fd2258d2ca871af73e55 Mon Sep 17 00:00:00 2001 From: zhaohu xing <920232796@qq.com> Date: Wed, 30 Nov 2022 14:56:12 +0800 Subject: fix bugs Signed-off-by: zhaohu xing <920232796@qq.com> --- configs/altdiffusion/ad-inference.yaml | 2 +- launch.py | 10 +- ldm/data/__init__.py | 0 ldm/data/base.py | 23 - ldm/data/imagenet.py | 394 ------- ldm/data/lsun.py | 92 -- ldm/lr_scheduler.py | 98 -- ldm/models/autoencoder.py | 443 -------- ldm/models/diffusion/__init__.py | 0 ldm/models/diffusion/classifier.py | 267 ----- ldm/models/diffusion/ddim.py | 241 ----- ldm/models/diffusion/ddpm.py | 1445 ------------------------- ldm/models/diffusion/dpm_solver/__init__.py | 1 - ldm/models/diffusion/dpm_solver/dpm_solver.py | 1184 -------------------- ldm/models/diffusion/dpm_solver/sampler.py | 82 -- ldm/models/diffusion/plms.py | 236 ---- ldm/modules/attention.py | 261 ----- ldm/modules/diffusionmodules/__init__.py | 0 ldm/modules/diffusionmodules/model.py | 835 -------------- ldm/modules/diffusionmodules/openaimodel.py | 961 ---------------- ldm/modules/diffusionmodules/util.py | 267 ----- ldm/modules/distributions/__init__.py | 0 ldm/modules/distributions/distributions.py | 92 -- ldm/modules/ema.py | 76 -- ldm/modules/encoders/__init__.py | 0 ldm/modules/encoders/modules.py | 234 ---- ldm/modules/encoders/xlmr.py | 137 --- ldm/modules/image_degradation/__init__.py | 2 - ldm/modules/image_degradation/bsrgan.py | 730 ------------- ldm/modules/image_degradation/bsrgan_light.py | 650 ----------- ldm/modules/image_degradation/utils/test.png | Bin 441072 -> 0 bytes ldm/modules/image_degradation/utils_image.py | 916 ---------------- ldm/modules/losses/__init__.py | 1 - ldm/modules/losses/contperceptual.py | 111 -- ldm/modules/losses/vqperceptual.py | 167 --- ldm/modules/x_transformer.py | 641 ----------- ldm/util.py | 203 ---- modules/sd_hijack.py | 15 +- modules/sd_hijack_clip.py | 10 +- modules/xlmr.py | 137 +++ 40 files changed, 159 insertions(+), 10805 deletions(-) delete mode 100644 ldm/data/__init__.py delete mode 100644 ldm/data/base.py delete mode 100644 ldm/data/imagenet.py delete mode 100644 ldm/data/lsun.py delete mode 100644 ldm/lr_scheduler.py delete mode 100644 ldm/models/autoencoder.py delete mode 100644 ldm/models/diffusion/__init__.py delete mode 100644 ldm/models/diffusion/classifier.py delete mode 100644 ldm/models/diffusion/ddim.py delete mode 100644 ldm/models/diffusion/ddpm.py delete mode 100644 ldm/models/diffusion/dpm_solver/__init__.py delete mode 100644 ldm/models/diffusion/dpm_solver/dpm_solver.py delete mode 100644 ldm/models/diffusion/dpm_solver/sampler.py delete 
mode 100644 ldm/models/diffusion/plms.py delete mode 100644 ldm/modules/attention.py delete mode 100644 ldm/modules/diffusionmodules/__init__.py delete mode 100644 ldm/modules/diffusionmodules/model.py delete mode 100644 ldm/modules/diffusionmodules/openaimodel.py delete mode 100644 ldm/modules/diffusionmodules/util.py delete mode 100644 ldm/modules/distributions/__init__.py delete mode 100644 ldm/modules/distributions/distributions.py delete mode 100644 ldm/modules/ema.py delete mode 100644 ldm/modules/encoders/__init__.py delete mode 100644 ldm/modules/encoders/modules.py delete mode 100644 ldm/modules/encoders/xlmr.py delete mode 100644 ldm/modules/image_degradation/__init__.py delete mode 100644 ldm/modules/image_degradation/bsrgan.py delete mode 100644 ldm/modules/image_degradation/bsrgan_light.py delete mode 100644 ldm/modules/image_degradation/utils/test.png delete mode 100644 ldm/modules/image_degradation/utils_image.py delete mode 100644 ldm/modules/losses/__init__.py delete mode 100644 ldm/modules/losses/contperceptual.py delete mode 100644 ldm/modules/losses/vqperceptual.py delete mode 100644 ldm/modules/x_transformer.py delete mode 100644 ldm/util.py create mode 100644 modules/xlmr.py diff --git a/configs/altdiffusion/ad-inference.yaml b/configs/altdiffusion/ad-inference.yaml index 1b11b63e..cfbee72d 100644 --- a/configs/altdiffusion/ad-inference.yaml +++ b/configs/altdiffusion/ad-inference.yaml @@ -67,6 +67,6 @@ model: target: torch.nn.Identity cond_stage_config: - target: ldm.modules.encoders.xlmr.BertSeriesModelWithTransformation + target: modules.xlmr.BertSeriesModelWithTransformation params: name: "XLMR-Large" \ No newline at end of file diff --git a/launch.py b/launch.py index ad9ddd5a..3f4dc870 100644 --- a/launch.py +++ b/launch.py @@ -233,11 +233,11 @@ def prepare_enviroment(): os.makedirs(dir_repos, exist_ok=True) - git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) - git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash) - git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) - git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) - git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) + git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", ) + git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", ) + git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", ) + git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", ) + git_clone(blip_repo, repo_dir('BLIP'), "BLIP", ) if not is_installed("lpips"): run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer") diff --git a/ldm/data/__init__.py b/ldm/data/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ldm/data/base.py b/ldm/data/base.py deleted file mode 100644 index b196c2f7..00000000 --- a/ldm/data/base.py +++ /dev/null @@ -1,23 +0,0 @@ -from abc import abstractmethod -from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset - - -class Txt2ImgIterableBaseDataset(IterableDataset): - ''' - Define an interface to make the IterableDatasets for text2img data chainable - ''' - def __init__(self, num_records=0, valid_ids=None, size=256): - super().__init__() - self.num_records = num_records - 
self.valid_ids = valid_ids - self.sample_ids = valid_ids - self.size = size - - print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') - - def __len__(self): - return self.num_records - - @abstractmethod - def __iter__(self): - pass \ No newline at end of file diff --git a/ldm/data/imagenet.py b/ldm/data/imagenet.py deleted file mode 100644 index 1c473f9c..00000000 --- a/ldm/data/imagenet.py +++ /dev/null @@ -1,394 +0,0 @@ -import os, yaml, pickle, shutil, tarfile, glob -import cv2 -import albumentations -import PIL -import numpy as np -import torchvision.transforms.functional as TF -from omegaconf import OmegaConf -from functools import partial -from PIL import Image -from tqdm import tqdm -from torch.utils.data import Dataset, Subset - -import taming.data.utils as tdu -from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve -from taming.data.imagenet import ImagePaths - -from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light - - -def synset2idx(path_to_yaml="data/index_synset.yaml"): - with open(path_to_yaml) as f: - di2s = yaml.load(f) - return dict((v,k) for k,v in di2s.items()) - - -class ImageNetBase(Dataset): - def __init__(self, config=None): - self.config = config or OmegaConf.create() - if not type(self.config)==dict: - self.config = OmegaConf.to_container(self.config) - self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) - self.process_images = True # if False we skip loading & processing images and self.data contains filepaths - self._prepare() - self._prepare_synset_to_human() - self._prepare_idx_to_synset() - self._prepare_human_to_integer_label() - self._load() - - def __len__(self): - return len(self.data) - - def __getitem__(self, i): - return self.data[i] - - def _prepare(self): - raise NotImplementedError() - - def _filter_relpaths(self, relpaths): - ignore = set([ - "n06596364_9591.JPEG", - ]) - relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] - if "sub_indices" in self.config: - indices = str_to_indices(self.config["sub_indices"]) - synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings - self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) - files = [] - for rpath in relpaths: - syn = rpath.split("/")[0] - if syn in synsets: - files.append(rpath) - return files - else: - return relpaths - - def _prepare_synset_to_human(self): - SIZE = 2655750 - URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" - self.human_dict = os.path.join(self.root, "synset_human.txt") - if (not os.path.exists(self.human_dict) or - not os.path.getsize(self.human_dict)==SIZE): - download(URL, self.human_dict) - - def _prepare_idx_to_synset(self): - URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" - self.idx2syn = os.path.join(self.root, "index_synset.yaml") - if (not os.path.exists(self.idx2syn)): - download(URL, self.idx2syn) - - def _prepare_human_to_integer_label(self): - URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" - self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") - if (not os.path.exists(self.human2integer)): - download(URL, self.human2integer) - with open(self.human2integer, "r") as f: - lines = f.read().splitlines() - assert len(lines) == 1000 - self.human2integer_dict = dict() - for line in lines: - value, key = line.split(":") - self.human2integer_dict[key] = int(value) - - def _load(self): - 
with open(self.txt_filelist, "r") as f: - self.relpaths = f.read().splitlines() - l1 = len(self.relpaths) - self.relpaths = self._filter_relpaths(self.relpaths) - print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) - - self.synsets = [p.split("/")[0] for p in self.relpaths] - self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] - - unique_synsets = np.unique(self.synsets) - class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) - if not self.keep_orig_class_label: - self.class_labels = [class_dict[s] for s in self.synsets] - else: - self.class_labels = [self.synset2idx[s] for s in self.synsets] - - with open(self.human_dict, "r") as f: - human_dict = f.read().splitlines() - human_dict = dict(line.split(maxsplit=1) for line in human_dict) - - self.human_labels = [human_dict[s] for s in self.synsets] - - labels = { - "relpath": np.array(self.relpaths), - "synsets": np.array(self.synsets), - "class_label": np.array(self.class_labels), - "human_label": np.array(self.human_labels), - } - - if self.process_images: - self.size = retrieve(self.config, "size", default=256) - self.data = ImagePaths(self.abspaths, - labels=labels, - size=self.size, - random_crop=self.random_crop, - ) - else: - self.data = self.abspaths - - -class ImageNetTrain(ImageNetBase): - NAME = "ILSVRC2012_train" - URL = "http://www.image-net.org/challenges/LSVRC/2012/" - AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" - FILES = [ - "ILSVRC2012_img_train.tar", - ] - SIZES = [ - 147897477120, - ] - - def __init__(self, process_images=True, data_root=None, **kwargs): - self.process_images = process_images - self.data_root = data_root - super().__init__(**kwargs) - - def _prepare(self): - if self.data_root: - self.root = os.path.join(self.data_root, self.NAME) - else: - cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) - self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) - - self.datadir = os.path.join(self.root, "data") - self.txt_filelist = os.path.join(self.root, "filelist.txt") - self.expected_length = 1281167 - self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", - default=True) - if not tdu.is_prepared(self.root): - # prep - print("Preparing dataset {} in {}".format(self.NAME, self.root)) - - datadir = self.datadir - if not os.path.exists(datadir): - path = os.path.join(self.root, self.FILES[0]) - if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: - import academictorrents as at - atpath = at.get(self.AT_HASH, datastore=self.root) - assert atpath == path - - print("Extracting {} to {}".format(path, datadir)) - os.makedirs(datadir, exist_ok=True) - with tarfile.open(path, "r:") as tar: - tar.extractall(path=datadir) - - print("Extracting sub-tars.") - subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) - for subpath in tqdm(subpaths): - subdir = subpath[:-len(".tar")] - os.makedirs(subdir, exist_ok=True) - with tarfile.open(subpath, "r:") as tar: - tar.extractall(path=subdir) - - filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) - filelist = [os.path.relpath(p, start=datadir) for p in filelist] - filelist = sorted(filelist) - filelist = "\n".join(filelist)+"\n" - with open(self.txt_filelist, "w") as f: - f.write(filelist) - - tdu.mark_prepared(self.root) - - -class ImageNetValidation(ImageNetBase): - NAME = "ILSVRC2012_validation" - URL = "http://www.image-net.org/challenges/LSVRC/2012/" - AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" - 
VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" - FILES = [ - "ILSVRC2012_img_val.tar", - "validation_synset.txt", - ] - SIZES = [ - 6744924160, - 1950000, - ] - - def __init__(self, process_images=True, data_root=None, **kwargs): - self.data_root = data_root - self.process_images = process_images - super().__init__(**kwargs) - - def _prepare(self): - if self.data_root: - self.root = os.path.join(self.data_root, self.NAME) - else: - cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) - self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) - self.datadir = os.path.join(self.root, "data") - self.txt_filelist = os.path.join(self.root, "filelist.txt") - self.expected_length = 50000 - self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", - default=False) - if not tdu.is_prepared(self.root): - # prep - print("Preparing dataset {} in {}".format(self.NAME, self.root)) - - datadir = self.datadir - if not os.path.exists(datadir): - path = os.path.join(self.root, self.FILES[0]) - if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: - import academictorrents as at - atpath = at.get(self.AT_HASH, datastore=self.root) - assert atpath == path - - print("Extracting {} to {}".format(path, datadir)) - os.makedirs(datadir, exist_ok=True) - with tarfile.open(path, "r:") as tar: - tar.extractall(path=datadir) - - vspath = os.path.join(self.root, self.FILES[1]) - if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: - download(self.VS_URL, vspath) - - with open(vspath, "r") as f: - synset_dict = f.read().splitlines() - synset_dict = dict(line.split() for line in synset_dict) - - print("Reorganizing into synset folders") - synsets = np.unique(list(synset_dict.values())) - for s in synsets: - os.makedirs(os.path.join(datadir, s), exist_ok=True) - for k, v in synset_dict.items(): - src = os.path.join(datadir, k) - dst = os.path.join(datadir, v) - shutil.move(src, dst) - - filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) - filelist = [os.path.relpath(p, start=datadir) for p in filelist] - filelist = sorted(filelist) - filelist = "\n".join(filelist)+"\n" - with open(self.txt_filelist, "w") as f: - f.write(filelist) - - tdu.mark_prepared(self.root) - - - -class ImageNetSR(Dataset): - def __init__(self, size=None, - degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., - random_crop=True): - """ - Imagenet Superresolution Dataloader - Performs following ops in order: - 1. crops a crop of size s from image either as random or center crop - 2. resizes crop to size with cv2.area_interpolation - 3. degrades resized crop with degradation_fn - - :param size: resizing to size after cropping - :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light - :param downscale_f: Low Resolution Downsample factor - :param min_crop_f: determines crop size s, - where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f) - :param max_crop_f: "" - :param data_root: - :param random_crop: - """ - self.base = self.get_base() - assert size - assert (size / downscale_f).is_integer() - self.size = size - self.LR_size = int(size / downscale_f) - self.min_crop_f = min_crop_f - self.max_crop_f = max_crop_f - assert(max_crop_f <= 1.) 
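# A minimal sketch of the three-step pipeline the ImageNetSR docstring above
# describes (crop -> area-interpolation resize -> degradation). The crop-size
# arithmetic and the [-1, 1] scaling mirror __getitem__ below; the bicubic
# downscale is only a stand-in for the configurable degradation_fn, not the
# deleted code itself.
import numpy as np
import cv2

def make_sr_pair(image, size=256, downscale_f=4, min_crop_f=0.5, max_crop_f=1.0):
    # 1. crop a square of side s = c * min(H, W), with c ~ U(min_crop_f, max_crop_f)
    side = int(min(image.shape[:2]) * np.random.uniform(min_crop_f, max_crop_f))
    y = np.random.randint(0, image.shape[0] - side + 1)
    x = np.random.randint(0, image.shape[1] - side + 1)
    crop = image[y:y + side, x:x + side]
    # 2. resize the crop to the target size with area interpolation
    hr = cv2.resize(crop, (size, size), interpolation=cv2.INTER_AREA)
    # 3. degrade the resized crop to its low-resolution counterpart
    lr = cv2.resize(hr, (size // downscale_f, size // downscale_f),
                    interpolation=cv2.INTER_CUBIC)
    return (hr / 127.5 - 1.0).astype(np.float32), (lr / 127.5 - 1.0).astype(np.float32)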
- self.center_crop = not random_crop - - self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) - - self.pil_interpolation = False # gets reset later if incase interp_op is from pillow - - if degradation == "bsrgan": - self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) - - elif degradation == "bsrgan_light": - self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) - - else: - interpolation_fn = { - "cv_nearest": cv2.INTER_NEAREST, - "cv_bilinear": cv2.INTER_LINEAR, - "cv_bicubic": cv2.INTER_CUBIC, - "cv_area": cv2.INTER_AREA, - "cv_lanczos": cv2.INTER_LANCZOS4, - "pil_nearest": PIL.Image.NEAREST, - "pil_bilinear": PIL.Image.BILINEAR, - "pil_bicubic": PIL.Image.BICUBIC, - "pil_box": PIL.Image.BOX, - "pil_hamming": PIL.Image.HAMMING, - "pil_lanczos": PIL.Image.LANCZOS, - }[degradation] - - self.pil_interpolation = degradation.startswith("pil_") - - if self.pil_interpolation: - self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) - - else: - self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, - interpolation=interpolation_fn) - - def __len__(self): - return len(self.base) - - def __getitem__(self, i): - example = self.base[i] - image = Image.open(example["file_path_"]) - - if not image.mode == "RGB": - image = image.convert("RGB") - - image = np.array(image).astype(np.uint8) - - min_side_len = min(image.shape[:2]) - crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) - crop_side_len = int(crop_side_len) - - if self.center_crop: - self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) - - else: - self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) - - image = self.cropper(image=image)["image"] - image = self.image_rescaler(image=image)["image"] - - if self.pil_interpolation: - image_pil = PIL.Image.fromarray(image) - LR_image = self.degradation_process(image_pil) - LR_image = np.array(LR_image).astype(np.uint8) - - else: - LR_image = self.degradation_process(image=image)["image"] - - example["image"] = (image/127.5 - 1.0).astype(np.float32) - example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) - - return example - - -class ImageNetSRTrain(ImageNetSR): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def get_base(self): - with open("data/imagenet_train_hr_indices.p", "rb") as f: - indices = pickle.load(f) - dset = ImageNetTrain(process_images=False,) - return Subset(dset, indices) - - -class ImageNetSRValidation(ImageNetSR): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def get_base(self): - with open("data/imagenet_val_hr_indices.p", "rb") as f: - indices = pickle.load(f) - dset = ImageNetValidation(process_images=False,) - return Subset(dset, indices) diff --git a/ldm/data/lsun.py b/ldm/data/lsun.py deleted file mode 100644 index 6256e457..00000000 --- a/ldm/data/lsun.py +++ /dev/null @@ -1,92 +0,0 @@ -import os -import numpy as np -import PIL -from PIL import Image -from torch.utils.data import Dataset -from torchvision import transforms - - -class LSUNBase(Dataset): - def __init__(self, - txt_file, - data_root, - size=None, - interpolation="bicubic", - flip_p=0.5 - ): - self.data_paths = txt_file - self.data_root = data_root - with open(self.data_paths, "r") as f: - self.image_paths = f.read().splitlines() - self._length = len(self.image_paths) - self.labels = { - "relative_file_path_": 
[l for l in self.image_paths], - "file_path_": [os.path.join(self.data_root, l) - for l in self.image_paths], - } - - self.size = size - self.interpolation = {"linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - }[interpolation] - self.flip = transforms.RandomHorizontalFlip(p=flip_p) - - def __len__(self): - return self._length - - def __getitem__(self, i): - example = dict((k, self.labels[k][i]) for k in self.labels) - image = Image.open(example["file_path_"]) - if not image.mode == "RGB": - image = image.convert("RGB") - - # default to score-sde preprocessing - img = np.array(image).astype(np.uint8) - crop = min(img.shape[0], img.shape[1]) - h, w, = img.shape[0], img.shape[1] - img = img[(h - crop) // 2:(h + crop) // 2, - (w - crop) // 2:(w + crop) // 2] - - image = Image.fromarray(img) - if self.size is not None: - image = image.resize((self.size, self.size), resample=self.interpolation) - - image = self.flip(image) - image = np.array(image).astype(np.uint8) - example["image"] = (image / 127.5 - 1.0).astype(np.float32) - return example - - -class LSUNChurchesTrain(LSUNBase): - def __init__(self, **kwargs): - super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) - - -class LSUNChurchesValidation(LSUNBase): - def __init__(self, flip_p=0., **kwargs): - super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", - flip_p=flip_p, **kwargs) - - -class LSUNBedroomsTrain(LSUNBase): - def __init__(self, **kwargs): - super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) - - -class LSUNBedroomsValidation(LSUNBase): - def __init__(self, flip_p=0.0, **kwargs): - super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", - flip_p=flip_p, **kwargs) - - -class LSUNCatsTrain(LSUNBase): - def __init__(self, **kwargs): - super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) - - -class LSUNCatsValidation(LSUNBase): - def __init__(self, flip_p=0., **kwargs): - super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", - flip_p=flip_p, **kwargs) diff --git a/ldm/lr_scheduler.py b/ldm/lr_scheduler.py deleted file mode 100644 index be39da9c..00000000 --- a/ldm/lr_scheduler.py +++ /dev/null @@ -1,98 +0,0 @@ -import numpy as np - - -class LambdaWarmUpCosineScheduler: - """ - note: use with a base_lr of 1.0 - """ - def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): - self.lr_warm_up_steps = warm_up_steps - self.lr_start = lr_start - self.lr_min = lr_min - self.lr_max = lr_max - self.lr_max_decay_steps = max_decay_steps - self.last_lr = 0. 
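# A minimal usage sketch for the warm-up/cosine schedulers this (deleted) module
# defines, best read after the class bodies below. Per the docstring above, the
# schedule returns the learning rate itself when the optimizer's base lr is 1.0,
# so it is wrapped in torch's LambdaLR, the same wiring VQModel.configure_optimizers
# uses further down in this patch. All hyperparameter values here are illustrative.
import torch
from torch.optim.lr_scheduler import LambdaLR

net = torch.nn.Linear(8, 8)                        # stand-in module
opt = torch.optim.Adam(net.parameters(), lr=1.0)   # base_lr of 1.0, per the docstring
sched_fn = LambdaWarmUpCosineScheduler(warm_up_steps=1000, lr_min=1e-6, lr_max=1e-4,
                                       lr_start=0.0, max_decay_steps=100_000)
scheduler = LambdaLR(opt, lr_lambda=sched_fn.schedule)
for _ in range(3):
    opt.step()        # effective lr = 1.0 * sched_fn.schedule(step)
    scheduler.step()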
- self.verbosity_interval = verbosity_interval - - def schedule(self, n, **kwargs): - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") - if n < self.lr_warm_up_steps: - lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start - self.last_lr = lr - return lr - else: - t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) - t = min(t, 1.0) - lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( - 1 + np.cos(t * np.pi)) - self.last_lr = lr - return lr - - def __call__(self, n, **kwargs): - return self.schedule(n,**kwargs) - - -class LambdaWarmUpCosineScheduler2: - """ - supports repeated iterations, configurable via lists - note: use with a base_lr of 1.0. - """ - def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): - assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) - self.lr_warm_up_steps = warm_up_steps - self.f_start = f_start - self.f_min = f_min - self.f_max = f_max - self.cycle_lengths = cycle_lengths - self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) - self.last_f = 0. - self.verbosity_interval = verbosity_interval - - def find_in_interval(self, n): - interval = 0 - for cl in self.cum_cycles[1:]: - if n <= cl: - return interval - interval += 1 - - def schedule(self, n, **kwargs): - cycle = self.find_in_interval(n) - n = n - self.cum_cycles[cycle] - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}") - if n < self.lr_warm_up_steps[cycle]: - f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] - self.last_f = f - return f - else: - t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) - t = min(t, 1.0) - f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( - 1 + np.cos(t * np.pi)) - self.last_f = f - return f - - def __call__(self, n, **kwargs): - return self.schedule(n, **kwargs) - - -class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): - - def schedule(self, n, **kwargs): - cycle = self.find_in_interval(n) - n = n - self.cum_cycles[cycle] - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}") - - if n < self.lr_warm_up_steps[cycle]: - f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] - self.last_f = f - return f - else: - f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) - self.last_f = f - return f - diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py deleted file mode 100644 index 6a9c4f45..00000000 --- a/ldm/models/autoencoder.py +++ /dev/null @@ -1,443 +0,0 @@ -import torch -import pytorch_lightning as pl -import torch.nn.functional as F -from contextlib import contextmanager - -from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer - -from ldm.modules.diffusionmodules.model import Encoder, Decoder -from ldm.modules.distributions.distributions import DiagonalGaussianDistribution - -from ldm.util import instantiate_from_config - - -class VQModel(pl.LightningModule): - def __init__(self, - ddconfig, - lossconfig, - n_embed, - embed_dim, - 
ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - batch_resize_range=None, - scheduler_config=None, - lr_g_factor=1.0, - remap=None, - sane_index_shape=False, # tell vector quantizer to return indices as bhw - use_ema=False - ): - super().__init__() - self.embed_dim = embed_dim - self.n_embed = n_embed - self.image_key = image_key - self.encoder = Encoder(**ddconfig) - self.decoder = Decoder(**ddconfig) - self.loss = instantiate_from_config(lossconfig) - self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, - remap=remap, - sane_index_shape=sane_index_shape) - self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) - if colorize_nlabels is not None: - assert type(colorize_nlabels)==int - self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) - if monitor is not None: - self.monitor = monitor - self.batch_resize_range = batch_resize_range - if self.batch_resize_range is not None: - print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") - - self.use_ema = use_ema - if self.use_ema: - self.model_ema = LitEma(self) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - self.scheduler_config = scheduler_config - self.lr_g_factor = lr_g_factor - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.parameters()) - self.model_ema.copy_to(self) - if context is not None: - print(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.parameters()) - if context is not None: - print(f"{context}: Restored training weights") - - def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - missing, unexpected = self.load_state_dict(sd, strict=False) - print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: - print(f"Missing Keys: {missing}") - print(f"Unexpected Keys: {unexpected}") - - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.model_ema(self) - - def encode(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - quant, emb_loss, info = self.quantize(h) - return quant, emb_loss, info - - def encode_to_prequant(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - return h - - def decode(self, quant): - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec - - def decode_code(self, code_b): - quant_b = self.quantize.embed_code(code_b) - dec = self.decode(quant_b) - return dec - - def forward(self, input, return_pred_indices=False): - quant, diff, (_,_,ind) = self.encode(input) - dec = self.decode(quant) - if return_pred_indices: - return dec, diff, ind - return dec, diff - - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() - if self.batch_resize_range is not None: - lower_size = self.batch_resize_range[0] - upper_size = self.batch_resize_range[1] - if self.global_step <= 4: - # do the first few batches with max size to avoid later oom - new_resize = 
upper_size - else: - new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) - if new_resize != x.shape[2]: - x = F.interpolate(x, size=new_resize, mode="bicubic") - x = x.detach() - return x - - def training_step(self, batch, batch_idx, optimizer_idx): - # https://github.com/pytorch/pytorch/issues/37142 - # try not to fool the heuristics - x = self.get_input(batch, self.image_key) - xrec, qloss, ind = self(x, return_pred_indices=True) - - if optimizer_idx == 0: - # autoencode - aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train", - predicted_indices=ind) - - self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) - return aeloss - - if optimizer_idx == 1: - # discriminator - discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) - return discloss - - def validation_step(self, batch, batch_idx): - log_dict = self._validation_step(batch, batch_idx) - with self.ema_scope(): - log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") - return log_dict - - def _validation_step(self, batch, batch_idx, suffix=""): - x = self.get_input(batch, self.image_key) - xrec, qloss, ind = self(x, return_pred_indices=True) - aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, - self.global_step, - last_layer=self.get_last_layer(), - split="val"+suffix, - predicted_indices=ind - ) - - discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, - self.global_step, - last_layer=self.get_last_layer(), - split="val"+suffix, - predicted_indices=ind - ) - rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] - self.log(f"val{suffix}/rec_loss", rec_loss, - prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) - self.log(f"val{suffix}/aeloss", aeloss, - prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) - if version.parse(pl.__version__) >= version.parse('1.4.0'): - del log_dict_ae[f"val{suffix}/rec_loss"] - self.log_dict(log_dict_ae) - self.log_dict(log_dict_disc) - return self.log_dict - - def configure_optimizers(self): - lr_d = self.learning_rate - lr_g = self.lr_g_factor*self.learning_rate - print("lr_d", lr_d) - print("lr_g", lr_g) - opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ - list(self.decoder.parameters())+ - list(self.quantize.parameters())+ - list(self.quant_conv.parameters())+ - list(self.post_quant_conv.parameters()), - lr=lr_g, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr_d, betas=(0.5, 0.9)) - - if self.scheduler_config is not None: - scheduler = instantiate_from_config(self.scheduler_config) - - print("Setting up LambdaLR scheduler...") - scheduler = [ - { - 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), - 'interval': 'step', - 'frequency': 1 - }, - { - 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), - 'interval': 'step', - 'frequency': 1 - }, - ] - return [opt_ae, opt_disc], scheduler - return [opt_ae, opt_disc], [] - - def get_last_layer(self): - return self.decoder.conv_out.weight - - def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): - log = dict() - x = self.get_input(batch, self.image_key) - x = x.to(self.device) - if only_inputs: - log["inputs"] = x - return log - xrec, _ = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert 
xrec.shape[1] > 3 - x = self.to_rgb(x) - xrec = self.to_rgb(xrec) - log["inputs"] = x - log["reconstructions"] = xrec - if plot_ema: - with self.ema_scope(): - xrec_ema, _ = self(x) - if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) - log["reconstructions_ema"] = xrec_ema - return log - - def to_rgb(self, x): - assert self.image_key == "segmentation" - if not hasattr(self, "colorize"): - self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) - x = F.conv2d(x, weight=self.colorize) - x = 2.*(x-x.min())/(x.max()-x.min()) - 1. - return x - - -class VQModelInterface(VQModel): - def __init__(self, embed_dim, *args, **kwargs): - super().__init__(embed_dim=embed_dim, *args, **kwargs) - self.embed_dim = embed_dim - - def encode(self, x): - h = self.encoder(x) - h = self.quant_conv(h) - return h - - def decode(self, h, force_not_quantize=False): - # also go through quantization layer - if not force_not_quantize: - quant, emb_loss, info = self.quantize(h) - else: - quant = h - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) - return dec - - -class AutoencoderKL(pl.LightningModule): - def __init__(self, - ddconfig, - lossconfig, - embed_dim, - ckpt_path=None, - ignore_keys=[], - image_key="image", - colorize_nlabels=None, - monitor=None, - ): - super().__init__() - self.image_key = image_key - self.encoder = Encoder(**ddconfig) - self.decoder = Decoder(**ddconfig) - self.loss = instantiate_from_config(lossconfig) - assert ddconfig["double_z"] - self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) - self.embed_dim = embed_dim - if colorize_nlabels is not None: - assert type(colorize_nlabels)==int - self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) - if monitor is not None: - self.monitor = monitor - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) - - def init_from_ckpt(self, path, ignore_keys=list()): - sd = torch.load(path, map_location="cpu")["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - self.load_state_dict(sd, strict=False) - print(f"Restored from {path}") - - def encode(self, x): - h = self.encoder(x) - moments = self.quant_conv(h) - posterior = DiagonalGaussianDistribution(moments) - return posterior - - def decode(self, z): - z = self.post_quant_conv(z) - dec = self.decoder(z) - return dec - - def forward(self, input, sample_posterior=True): - posterior = self.encode(input) - if sample_posterior: - z = posterior.sample() - else: - z = posterior.mode() - dec = self.decode(z) - return dec, posterior - - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() - return x - - def training_step(self, batch, batch_idx, optimizer_idx): - inputs = self.get_input(batch, self.image_key) - reconstructions, posterior = self(inputs) - - if optimizer_idx == 0: - # train encoder+decoder+logvar - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) - return aeloss - - if optimizer_idx == 1: - # train the 
discriminator - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, - last_layer=self.get_last_layer(), split="train") - - self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) - self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) - return discloss - - def validation_step(self, batch, batch_idx): - inputs = self.get_input(batch, self.image_key) - reconstructions, posterior = self(inputs) - aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, - last_layer=self.get_last_layer(), split="val") - - discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, - last_layer=self.get_last_layer(), split="val") - - self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) - self.log_dict(log_dict_ae) - self.log_dict(log_dict_disc) - return self.log_dict - - def configure_optimizers(self): - lr = self.learning_rate - opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ - list(self.decoder.parameters())+ - list(self.quant_conv.parameters())+ - list(self.post_quant_conv.parameters()), - lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr, betas=(0.5, 0.9)) - return [opt_ae, opt_disc], [] - - def get_last_layer(self): - return self.decoder.conv_out.weight - - @torch.no_grad() - def log_images(self, batch, only_inputs=False, **kwargs): - log = dict() - x = self.get_input(batch, self.image_key) - x = x.to(self.device) - if not only_inputs: - xrec, posterior = self(x) - if x.shape[1] > 3: - # colorize with random projection - assert xrec.shape[1] > 3 - x = self.to_rgb(x) - xrec = self.to_rgb(xrec) - log["samples"] = self.decode(torch.randn_like(posterior.sample())) - log["reconstructions"] = xrec - log["inputs"] = x - return log - - def to_rgb(self, x): - assert self.image_key == "segmentation" - if not hasattr(self, "colorize"): - self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) - x = F.conv2d(x, weight=self.colorize) - x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
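AutoencoderKL.encode above wraps the predicted moments in a DiagonalGaussianDistribution, and forward either draws a latent via the reparameterization trick (sample_posterior=True) or takes the mean. A minimal sketch of that distribution, assuming the moments tensor is mean and log-variance concatenated along the channel axis:

import torch

class DiagonalGaussian:
    def __init__(self, moments):
        # moments: (B, 2*C, H, W) -> per-pixel mean and log-variance
        self.mean, self.logvar = torch.chunk(moments, 2, dim=1)
        self.std = torch.exp(0.5 * self.logvar)

    def sample(self):
        # reparameterization: differentiable w.r.t. mean and std
        return self.mean + self.std * torch.randn_like(self.mean)

    def mode(self):
        return self.mean

    def kl(self):
        # KL(N(mean, var) || N(0, I)), summed over latent dimensions
        return 0.5 * torch.sum(
            self.mean ** 2 + self.logvar.exp() - 1.0 - self.logvar,
            dim=[1, 2, 3])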
- return x - - -class IdentityFirstStage(torch.nn.Module): - def __init__(self, *args, vq_interface=False, **kwargs): - self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff - super().__init__() - - def encode(self, x, *args, **kwargs): - return x - - def decode(self, x, *args, **kwargs): - return x - - def quantize(self, x, *args, **kwargs): - if self.vq_interface: - return x, None, [None, None, None] - return x - - def forward(self, x, *args, **kwargs): - return x diff --git a/ldm/models/diffusion/__init__.py b/ldm/models/diffusion/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ldm/models/diffusion/classifier.py b/ldm/models/diffusion/classifier.py deleted file mode 100644 index 67e98b9d..00000000 --- a/ldm/models/diffusion/classifier.py +++ /dev/null @@ -1,267 +0,0 @@ -import os -import torch -import pytorch_lightning as pl -from omegaconf import OmegaConf -from torch.nn import functional as F -from torch.optim import AdamW -from torch.optim.lr_scheduler import LambdaLR -from copy import deepcopy -from einops import rearrange -from glob import glob -from natsort import natsorted - -from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel -from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config - -__models__ = { - 'class_label': EncoderUNetModel, - 'segmentation': UNetModel -} - - -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" - return self - - -class NoisyLatentImageClassifier(pl.LightningModule): - - def __init__(self, - diffusion_path, - num_classes, - ckpt_path=None, - pool='attention', - label_key=None, - diffusion_ckpt_path=None, - scheduler_config=None, - weight_decay=1.e-2, - log_steps=10, - monitor='val/loss', - *args, - **kwargs): - super().__init__(*args, **kwargs) - self.num_classes = num_classes - # get latest config of diffusion model - diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] - self.diffusion_config = OmegaConf.load(diffusion_config).model - self.diffusion_config.params.ckpt_path = diffusion_ckpt_path - self.load_diffusion() - - self.monitor = monitor - self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 - self.log_time_interval = self.diffusion_model.num_timesteps // log_steps - self.log_steps = log_steps - - self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ - else self.diffusion_model.cond_stage_key - - assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' - - if self.label_key not in __models__: - raise NotImplementedError() - - self.load_classifier(ckpt_path, pool) - - self.scheduler_config = scheduler_config - self.use_scheduler = self.scheduler_config is not None - self.weight_decay = weight_decay - - def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): - sd = torch.load(path, map_location="cpu") - if "state_dict" in list(sd.keys()): - sd = sd["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( - sd, strict=False) - print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: - print(f"Missing 
Keys: {missing}") - if len(unexpected) > 0: - print(f"Unexpected Keys: {unexpected}") - - def load_diffusion(self): - model = instantiate_from_config(self.diffusion_config) - self.diffusion_model = model.eval() - self.diffusion_model.train = disabled_train - for param in self.diffusion_model.parameters(): - param.requires_grad = False - - def load_classifier(self, ckpt_path, pool): - model_config = deepcopy(self.diffusion_config.params.unet_config.params) - model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels - model_config.out_channels = self.num_classes - if self.label_key == 'class_label': - model_config.pool = pool - - self.model = __models__[self.label_key](**model_config) - if ckpt_path is not None: - print('#####################################################################') - print(f'load from ckpt "{ckpt_path}"') - print('#####################################################################') - self.init_from_ckpt(ckpt_path) - - @torch.no_grad() - def get_x_noisy(self, x, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x)) - continuous_sqrt_alpha_cumprod = None - if self.diffusion_model.use_continuous_noise: - continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) - # todo: make sure t+1 is correct here - - return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, - continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) - - def forward(self, x_noisy, t, *args, **kwargs): - return self.model(x_noisy, t) - - @torch.no_grad() - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = rearrange(x, 'b h w c -> b c h w') - x = x.to(memory_format=torch.contiguous_format).float() - return x - - @torch.no_grad() - def get_conditioning(self, batch, k=None): - if k is None: - k = self.label_key - assert k is not None, 'Needs to provide label key' - - targets = batch[k].to(self.device) - - if self.label_key == 'segmentation': - targets = rearrange(targets, 'b h w c -> b c h w') - for down in range(self.numd): - h, w = targets.shape[-2:] - targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') - - # targets = rearrange(targets,'b c h w -> b h w c') - - return targets - - def compute_top_k(self, logits, labels, k, reduction="mean"): - _, top_ks = torch.topk(logits, k, dim=1) - if reduction == "mean": - return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() - elif reduction == "none": - return (top_ks == labels[:, None]).float().sum(dim=-1) - - def on_train_epoch_start(self): - # save some memory - self.diffusion_model.model.to('cpu') - - @torch.no_grad() - def write_logs(self, loss, logits, targets): - log_prefix = 'train' if self.training else 'val' - log = {} - log[f"{log_prefix}/loss"] = loss.mean() - log[f"{log_prefix}/acc@1"] = self.compute_top_k( - logits, targets, k=1, reduction="mean" - ) - log[f"{log_prefix}/acc@5"] = self.compute_top_k( - logits, targets, k=5, reduction="mean" - ) - - self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) - self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) - self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) - lr = self.optimizers().param_groups[0]['lr'] - self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) - - def shared_step(self, batch, t=None): - x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) - targets = 
self.get_conditioning(batch) - if targets.dim() == 4: - targets = targets.argmax(dim=1) - if t is None: - t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() - else: - t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() - x_noisy = self.get_x_noisy(x, t) - logits = self(x_noisy, t) - - loss = F.cross_entropy(logits, targets, reduction='none') - - self.write_logs(loss.detach(), logits.detach(), targets.detach()) - - loss = loss.mean() - return loss, logits, x_noisy, targets - - def training_step(self, batch, batch_idx): - loss, *_ = self.shared_step(batch) - return loss - - def reset_noise_accs(self): - self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in - range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} - - def on_validation_start(self): - self.reset_noise_accs() - - @torch.no_grad() - def validation_step(self, batch, batch_idx): - loss, *_ = self.shared_step(batch) - - for t in self.noisy_acc: - _, logits, _, targets = self.shared_step(batch, t) - self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) - self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) - - return loss - - def configure_optimizers(self): - optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) - - if self.use_scheduler: - scheduler = instantiate_from_config(self.scheduler_config) - - print("Setting up LambdaLR scheduler...") - scheduler = [ - { - 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), - 'interval': 'step', - 'frequency': 1 - }] - return [optimizer], scheduler - - return optimizer - - @torch.no_grad() - def log_images(self, batch, N=8, *args, **kwargs): - log = dict() - x = self.get_input(batch, self.diffusion_model.first_stage_key) - log['inputs'] = x - - y = self.get_conditioning(batch) - - if self.label_key == 'class_label': - y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) - log['labels'] = y - - if ismap(y): - log['labels'] = self.diffusion_model.to_rgb(y) - - for step in range(self.log_steps): - current_time = step * self.log_time_interval - - _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) - - log[f'inputs@t{current_time}'] = x_noisy - - pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) - pred = rearrange(pred, 'b h w c -> b c h w') - - log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) - - for key in log: - log[key] = log[key][:N] - - return log diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py deleted file mode 100644 index fb31215d..00000000 --- a/ldm/models/diffusion/ddim.py +++ /dev/null @@ -1,241 +0,0 @@ -"""SAMPLING ONLY.""" - -import torch -import numpy as np -from tqdm import tqdm -from functools import partial - -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ - extract_into_tensor - - -class DDIMSampler(object): - def __init__(self, model, schedule="linear", **kwargs): - super().__init__() - self.model = model - self.ddpm_num_timesteps = model.num_timesteps - self.schedule = schedule - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) - setattr(self, name, attr) - - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): - self.ddim_timesteps = 
make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) - alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) - - # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) - - @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
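make_schedule selects a short subsequence of the DDPM timesteps and derives a sigma for each from the cumulative alphas; eta = 0 gives the fully deterministic DDIM sampler. A self-contained sketch of those quantities, assuming a uniform subsequence and a toy linear beta schedule:

import numpy as np

T, S, eta = 1000, 50, 0.0
betas = np.linspace(1e-4, 2e-2, T)
alphas_cumprod = np.cumprod(1.0 - betas)

ddim_timesteps = np.arange(0, T, T // S)             # uniform subsequence
a_t = alphas_cumprod[ddim_timesteps]
a_prev = np.append(alphas_cumprod[0], a_t[:-1])
sigmas = eta * np.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))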
- **kwargs - ): - if conditioning is not None: - if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - print(f'Data shape for DDIM sampling is {size}, eta {eta}') - - samples, intermediates = self.ddim_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - ) - return samples, intermediates - - @torch.no_grad() - def ddim_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None,): - device = self.model.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - if timesteps is None: - timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps - elif timesteps is not None and not ddim_use_original_steps: - subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 - timesteps = self.ddim_timesteps[:subset_end] - - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - print(f"Running DDIM Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - - if mask is not None: - assert x0 is not None - img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? - img = img_orig * mask + (1. 
- mask) * img - - outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning) - img, pred_x0 = outs - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - - return img, intermediates - - @torch.no_grad() - def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None): - b, *_, device = *x.shape, x.device - - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: - e_t = self.model.apply_model(x, t, c) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - c_in = torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - - if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) - - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - # direction pointing to x_t - dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - @torch.no_grad() - def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): - # fast, but does not allow for exact reconstruction - # t serves as an index to gather the correct alphas - if use_original_steps: - sqrt_alphas_cumprod = self.sqrt_alphas_cumprod - sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod - else: - sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) - sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas - - if noise is None: - noise = torch.randn_like(x0) - return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + - extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) - - @torch.no_grad() - def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - use_original_steps=False): - - timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps - timesteps = timesteps[:t_start] - - time_range = np.flip(timesteps) - total_steps = timesteps.shape[0] - print(f"Running DDIM Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc='Decoding image', total=total_steps) - x_dec = x_latent - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) - x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning) - return x_dec \ No newline at end of file diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py deleted file mode 100644 index bbedd04c..00000000 --- a/ldm/models/diffusion/ddpm.py +++ /dev/null @@ -1,1445 +0,0 @@ -""" -wild mixture of -https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py -https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py -https://github.com/CompVis/taming-transformers --- merci -""" - -import torch -import torch.nn as nn -import numpy as np -import pytorch_lightning as pl -from torch.optim.lr_scheduler import LambdaLR -from einops import rearrange, repeat -from contextlib import contextmanager -from functools import partial -from tqdm import tqdm -from torchvision.utils import make_grid -from pytorch_lightning.utilities.distributed import rank_zero_only - -from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config -from ldm.modules.ema import LitEma -from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution -from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL -from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like -from ldm.models.diffusion.ddim import DDIMSampler - - -__conditioning_keys__ = {'concat': 'c_concat', - 'crossattn': 'c_crossattn', - 'adm': 'y'} - - -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not 
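p_sample_ddim above packs classifier-free guidance and the DDIM update into one step: the unconditional and conditional noise predictions come from a single batched forward pass, are extrapolated by the guidance scale, and then yield pred_x0, the direction term, and x_prev. The same step condensed into one function (a sketch; a_t, a_prev and sigma_t are assumed to be already-gathered per-step tensors):

import torch

@torch.no_grad()
def ddim_step(model, x, t, cond, uncond, scale, a_t, a_prev, sigma_t):
    # one batched pass for both guidance branches
    x_in, t_in = torch.cat([x, x]), torch.cat([t, t])
    c_in = torch.cat([uncond, cond])
    e_uncond, e_cond = model(x_in, t_in, c_in).chunk(2)
    e_t = e_uncond + scale * (e_cond - e_uncond)          # classifier-free guidance

    pred_x0 = (x - (1 - a_t).sqrt() * e_t) / a_t.sqrt()   # current x0 estimate
    dir_xt = (1 - a_prev - sigma_t ** 2).sqrt() * e_t     # direction pointing to x_t
    noise = sigma_t * torch.randn_like(x)                 # vanishes for eta == 0
    return a_prev.sqrt() * pred_x0 + dir_xt + noise, pred_x0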
change anymore.""" - return self - - -def uniform_on_device(r1, r2, shape, device): - return (r1 - r2) * torch.rand(*shape, device=device) + r2 - - -class DDPM(pl.LightningModule): - # classic DDPM with Gaussian diffusion, in image space - def __init__(self, - unet_config, - timesteps=1000, - beta_schedule="linear", - loss_type="l2", - ckpt_path=None, - ignore_keys=[], - load_only_unet=False, - monitor="val/loss", - use_ema=True, - first_stage_key="image", - image_size=256, - channels=3, - log_every_t=100, - clip_denoised=True, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - given_betas=None, - original_elbo_weight=0., - v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta - l_simple_weight=1., - conditioning_key=None, - parameterization="eps", # all assuming fixed variance schedules - scheduler_config=None, - use_positional_encodings=False, - learn_logvar=False, - logvar_init=0., - ): - super().__init__() - assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' - self.parameterization = parameterization - print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") - self.cond_stage_model = None - self.clip_denoised = clip_denoised - self.log_every_t = log_every_t - self.first_stage_key = first_stage_key - self.image_size = image_size # try conv? - self.channels = channels - self.use_positional_encodings = use_positional_encodings - self.model = DiffusionWrapper(unet_config, conditioning_key) - count_params(self.model, verbose=True) - self.use_ema = use_ema - if self.use_ema: - self.model_ema = LitEma(self.model) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - - self.use_scheduler = scheduler_config is not None - if self.use_scheduler: - self.scheduler_config = scheduler_config - - self.v_posterior = v_posterior - self.original_elbo_weight = original_elbo_weight - self.l_simple_weight = l_simple_weight - - if monitor is not None: - self.monitor = monitor - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) - - self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, - linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) - - self.loss_type = loss_type - - self.learn_logvar = learn_logvar - self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) - if self.learn_logvar: - self.logvar = nn.Parameter(self.logvar, requires_grad=True) - - - def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - if exists(given_betas): - betas = given_betas - else: - betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, - cosine_s=cosine_s) - alphas = 1. 
- betas - alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) - - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.linear_start = linear_start - self.linear_end = linear_end - assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' - - to_torch = partial(torch.tensor, dtype=torch.float32) - - self.register_buffer('betas', to_torch(betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) - - # calculations for posterior q(x_{t-1} | x_t, x_0) - posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( - 1. - alphas_cumprod) + self.v_posterior * betas - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - self.register_buffer('posterior_variance', to_torch(posterior_variance)) - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) - self.register_buffer('posterior_mean_coef1', to_torch( - betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) - self.register_buffer('posterior_mean_coef2', to_torch( - (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) - - if self.parameterization == "eps": - lvlb_weights = self.betas ** 2 / ( - 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) - elif self.parameterization == "x0": - lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. 
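register_schedule above precomputes every per-timestep coefficient once and registers them as buffers, so training and sampling only ever index into them. The core quantities reduced to NumPy (toy linear schedule, v_posterior = 0 assumed):

import numpy as np

T = 1000
betas = np.linspace(1e-4, 2e-2, T)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

# q(x_t | x_0) marginal: x_t = sqrt(ac_t) * x_0 + sqrt(1 - ac_t) * eps
sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)

# q(x_{t-1} | x_t, x_0) posterior variance
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)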
* 1 - torch.Tensor(alphas_cumprod)) - else: - raise NotImplementedError("mu not supported") - # TODO how to choose this term - lvlb_weights[0] = lvlb_weights[1] - self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) - assert not torch.isnan(self.lvlb_weights).all() - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.model.parameters()) - self.model_ema.copy_to(self.model) - if context is not None: - print(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.model.parameters()) - if context is not None: - print(f"{context}: Restored training weights") - - def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): - sd = torch.load(path, map_location="cpu") - if "state_dict" in list(sd.keys()): - sd = sd["state_dict"] - keys = list(sd.keys()) - for k in keys: - for ik in ignore_keys: - if k.startswith(ik): - print("Deleting key {} from state_dict.".format(k)) - del sd[k] - missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( - sd, strict=False) - print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: - print(f"Missing Keys: {missing}") - if len(unexpected) > 0: - print(f"Unexpected Keys: {unexpected}") - - def q_mean_variance(self, x_start, t): - """ - Get the distribution q(x_t | x_0). - :param x_start: the [N x C x ...] tensor of noiseless inputs. - :param t: the number of diffusion steps (minus 1). Here, 0 means one step. - :return: A tuple (mean, variance, log_variance), all of x_start's shape. - """ - mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) - variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) - log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) - return mean, variance, log_variance - - def predict_start_from_noise(self, x_t, t, noise): - return ( - extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise - ) - - def q_posterior(self, x_start, x_t, t): - posterior_mean = ( - extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + - extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t - ) - posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - def p_mean_variance(self, x, t, clip_denoised: bool): - model_out = self.model(x, t) - if self.parameterization == "eps": - x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) - elif self.parameterization == "x0": - x_recon = model_out - if clip_denoised: - x_recon.clamp_(-1., 1.) 
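p_mean_variance chains two closed-form identities: invert the forward process to get an x0 estimate from the predicted noise, then plug that estimate into the Gaussian posterior q(x_{t-1} | x_t, x_0). Condensed, with the schedule coefficients for a single timestep passed in as scalars (an illustrative simplification):

def posterior_mean_from_eps(x_t, eps_pred, ac_t, ac_prev, alpha_t, beta_t):
    # x_t = sqrt(ac_t) * x0 + sqrt(1 - ac_t) * eps  =>  solve for x0
    x0_hat = (1.0 / ac_t) ** 0.5 * x_t - (1.0 / ac_t - 1.0) ** 0.5 * eps_pred
    x0_hat = x0_hat.clamp(-1.0, 1.0)                         # clip_denoised
    # mean of q(x_{t-1} | x_t, x0_hat)
    coef1 = beta_t * ac_prev ** 0.5 / (1.0 - ac_t)           # posterior_mean_coef1
    coef2 = (1.0 - ac_prev) * alpha_t ** 0.5 / (1.0 - ac_t)  # posterior_mean_coef2
    return coef1 * x0_hat + coef2 * x_t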
- - model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) - return model_mean, posterior_variance, posterior_log_variance - - @torch.no_grad() - def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): - b, *_, device = *x.shape, x.device - model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) - noise = noise_like(x.shape, device, repeat_noise) - # no noise when t == 0 - nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise - - @torch.no_grad() - def p_sample_loop(self, shape, return_intermediates=False): - device = self.betas.device - b = shape[0] - img = torch.randn(shape, device=device) - intermediates = [img] - for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): - img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), - clip_denoised=self.clip_denoised) - if i % self.log_every_t == 0 or i == self.num_timesteps - 1: - intermediates.append(img) - if return_intermediates: - return img, intermediates - return img - - @torch.no_grad() - def sample(self, batch_size=16, return_intermediates=False): - image_size = self.image_size - channels = self.channels - return self.p_sample_loop((batch_size, channels, image_size, image_size), - return_intermediates=return_intermediates) - - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) - - def get_loss(self, pred, target, mean=True): - if self.loss_type == 'l1': - loss = (target - pred).abs() - if mean: - loss = loss.mean() - elif self.loss_type == 'l2': - if mean: - loss = torch.nn.functional.mse_loss(target, pred) - else: - loss = torch.nn.functional.mse_loss(target, pred, reduction='none') - else: - raise NotImplementedError("unknown loss type '{loss_type}'") - - return loss - - def p_losses(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - model_out = self.model(x_noisy, t) - - loss_dict = {} - if self.parameterization == "eps": - target = noise - elif self.parameterization == "x0": - target = x_start - else: - raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") - - loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) - - log_prefix = 'train' if self.training else 'val' - - loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) - loss_simple = loss.mean() * self.l_simple_weight - - loss_vlb = (self.lvlb_weights[t] * loss).mean() - loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) - - loss = loss_simple + self.original_elbo_weight * loss_vlb - - loss_dict.update({f'{log_prefix}/loss': loss}) - - return loss, loss_dict - - def forward(self, x, *args, **kwargs): - # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size - # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' - t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() - return self.p_losses(x, t, *args, **kwargs) - - def get_input(self, batch, k): - x = batch[k] - if len(x.shape) == 3: - x = x[..., None] - x = rearrange(x, 'b h w c -> b c h w') - x = 
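p_losses above is the whole training objective for the eps parameterization: draw a random timestep, noise x0 in closed form via q_sample, and regress the model output onto the injected noise. The same objective in a few lines:

import torch
import torch.nn.functional as F

def simple_loss(model, x0, sqrt_ac, sqrt_1m_ac):
    # sqrt_ac, sqrt_1m_ac: (T,) buffers of sqrt(alphas_cumprod), sqrt(1 - alphas_cumprod)
    b = x0.shape[0]
    t = torch.randint(0, sqrt_ac.shape[0], (b,), device=x0.device)
    noise = torch.randn_like(x0)
    gather = lambda v: v[t].view(b, 1, 1, 1)                   # per-sample coefficient
    x_t = gather(sqrt_ac) * x0 + gather(sqrt_1m_ac) * noise    # q_sample
    return F.mse_loss(model(x_t, t), noise)                    # 'l2' loss on eps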
x.to(memory_format=torch.contiguous_format).float() - return x - - def shared_step(self, batch): - x = self.get_input(batch, self.first_stage_key) - loss, loss_dict = self(x) - return loss, loss_dict - - def training_step(self, batch, batch_idx): - loss, loss_dict = self.shared_step(batch) - - self.log_dict(loss_dict, prog_bar=True, - logger=True, on_step=True, on_epoch=True) - - self.log("global_step", self.global_step, - prog_bar=True, logger=True, on_step=True, on_epoch=False) - - if self.use_scheduler: - lr = self.optimizers().param_groups[0]['lr'] - self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) - - return loss - - @torch.no_grad() - def validation_step(self, batch, batch_idx): - _, loss_dict_no_ema = self.shared_step(batch) - with self.ema_scope(): - _, loss_dict_ema = self.shared_step(batch) - loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} - self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) - self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) - - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.model_ema(self.model) - - def _get_rows_from_list(self, samples): - n_imgs_per_row = len(samples) - denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') - denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') - denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) - return denoise_grid - - @torch.no_grad() - def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): - log = dict() - x = self.get_input(batch, self.first_stage_key) - N = min(x.shape[0], N) - n_row = min(x.shape[0], n_row) - x = x.to(self.device)[:N] - log["inputs"] = x - - # get diffusion row - diffusion_row = list() - x_start = x[:n_row] - - for t in range(self.num_timesteps): - if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) - t = t.to(self.device).long() - noise = torch.randn_like(x_start) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - diffusion_row.append(x_noisy) - - log["diffusion_row"] = self._get_rows_from_list(diffusion_row) - - if sample: - # get denoise row - with self.ema_scope("Plotting"): - samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) - - log["samples"] = samples - log["denoise_row"] = self._get_rows_from_list(denoise_row) - - if return_keys: - if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: - return log - else: - return {key: log[key] for key in return_keys} - return log - - def configure_optimizers(self): - lr = self.learning_rate - params = list(self.model.parameters()) - if self.learn_logvar: - params = params + [self.logvar] - opt = torch.optim.AdamW(params, lr=lr) - return opt - - -class LatentDiffusion(DDPM): - """main class""" - def __init__(self, - first_stage_config, - cond_stage_config, - num_timesteps_cond=None, - cond_stage_key="image", - cond_stage_trainable=False, - concat_mode=True, - cond_stage_forward=None, - conditioning_key=None, - scale_factor=1.0, - scale_by_std=False, - *args, **kwargs): - self.num_timesteps_cond = default(num_timesteps_cond, 1) - self.scale_by_std = scale_by_std - assert self.num_timesteps_cond <= kwargs['timesteps'] - # for backwards compatibility after implementation of DiffusionWrapper - if conditioning_key is None: - conditioning_key = 'concat' if concat_mode else 'crossattn' - if cond_stage_config == '__is_unconditional__': - 
conditioning_key = None - ckpt_path = kwargs.pop("ckpt_path", None) - ignore_keys = kwargs.pop("ignore_keys", []) - super().__init__(conditioning_key=conditioning_key, *args, **kwargs) - self.concat_mode = concat_mode - self.cond_stage_trainable = cond_stage_trainable - self.cond_stage_key = cond_stage_key - try: - self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 - except: - self.num_downs = 0 - if not scale_by_std: - self.scale_factor = scale_factor - else: - self.register_buffer('scale_factor', torch.tensor(scale_factor)) - self.instantiate_first_stage(first_stage_config) - self.instantiate_cond_stage(cond_stage_config) - self.cond_stage_forward = cond_stage_forward - self.clip_denoised = False - self.bbox_tokenizer = None - - self.restarted_from_ckpt = False - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys) - self.restarted_from_ckpt = True - - def make_cond_schedule(self, ): - self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) - ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() - self.cond_ids[:self.num_timesteps_cond] = ids - - @rank_zero_only - @torch.no_grad() - def on_train_batch_start(self, batch, batch_idx, dataloader_idx): - # only for very first batch - if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: - assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' - # set rescale weight to 1./std of encodings - print("### USING STD-RESCALING ###") - x = super().get_input(batch, self.first_stage_key) - x = x.to(self.device) - encoder_posterior = self.encode_first_stage(x) - z = self.get_first_stage_encoding(encoder_posterior).detach() - del self.scale_factor - self.register_buffer('scale_factor', 1. 
/ z.flatten().std()) - print(f"setting self.scale_factor to {self.scale_factor}") - print("### USING STD-RESCALING ###") - - def register_schedule(self, - given_betas=None, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) - - self.shorten_cond_schedule = self.num_timesteps_cond > 1 - if self.shorten_cond_schedule: - self.make_cond_schedule() - - def instantiate_first_stage(self, config): - model = instantiate_from_config(config) - self.first_stage_model = model.eval() - self.first_stage_model.train = disabled_train - for param in self.first_stage_model.parameters(): - param.requires_grad = False - - def instantiate_cond_stage(self, config): - if not self.cond_stage_trainable: - if config == "__is_first_stage__": - print("Using first stage also as cond stage.") - self.cond_stage_model = self.first_stage_model - elif config == "__is_unconditional__": - print(f"Training {self.__class__.__name__} as an unconditional model.") - self.cond_stage_model = None - # self.be_unconditional = True - else: - model = instantiate_from_config(config) - self.cond_stage_model = model.eval() - self.cond_stage_model.train = disabled_train - for param in self.cond_stage_model.parameters(): - param.requires_grad = False - else: - assert config != '__is_first_stage__' - assert config != '__is_unconditional__' - model = instantiate_from_config(config) - self.cond_stage_model = model - - def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): - denoise_row = [] - for zd in tqdm(samples, desc=desc): - denoise_row.append(self.decode_first_stage(zd.to(self.device), - force_not_quantize=force_no_decoder_quantization)) - n_imgs_per_row = len(denoise_row) - denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W - denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') - denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') - denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) - return denoise_grid - - def get_first_stage_encoding(self, encoder_posterior): - if isinstance(encoder_posterior, DiagonalGaussianDistribution): - z = encoder_posterior.sample() - elif isinstance(encoder_posterior, torch.Tensor): - z = encoder_posterior - else: - raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") - return self.scale_factor * z - - def get_learned_conditioning(self, c): - if self.cond_stage_forward is None: - if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): - c = self.cond_stage_model.encode(c) - if isinstance(c, DiagonalGaussianDistribution): - c = c.mode() - else: - c = self.cond_stage_model(c) - else: - assert hasattr(self.cond_stage_model, self.cond_stage_forward) - c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) - return c - - def meshgrid(self, h, w): - y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) - x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) - - arr = torch.cat([y, x], dim=-1) - return arr - - def delta_border(self, h, w): - """ - :param h: height - :param w: width - :return: normalized distance to image border, - wtith min distance = 0 at border and max dist = 0.5 at image center - """ - lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) - arr = self.meshgrid(h, w) / lower_right_corner - dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] - dist_right_down = 
torch.min(1 - arr, dim=-1, keepdims=True)[0] - edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] - return edge_dist - - def get_weighting(self, h, w, Ly, Lx, device): - weighting = self.delta_border(h, w) - weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], - self.split_input_params["clip_max_weight"], ) - weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) - - if self.split_input_params["tie_braker"]: - L_weighting = self.delta_border(Ly, Lx) - L_weighting = torch.clip(L_weighting, - self.split_input_params["clip_min_tie_weight"], - self.split_input_params["clip_max_tie_weight"]) - - L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) - weighting = weighting * L_weighting - return weighting - - def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code - """ - :param x: img of size (bs, c, h, w) - :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) - """ - bs, nc, h, w = x.shape - - # number of crops in image - Ly = (h - kernel_size[0]) // stride[0] + 1 - Lx = (w - kernel_size[1]) // stride[1] + 1 - - if uf == 1 and df == 1: - fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) - unfold = torch.nn.Unfold(**fold_params) - - fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) - - weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap - weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) - - elif uf > 1 and df == 1: - fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) - unfold = torch.nn.Unfold(**fold_params) - - fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), - dilation=1, padding=0, - stride=(stride[0] * uf, stride[1] * uf)) - fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) - - weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap - weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) - - elif df > 1 and uf == 1: - fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) - unfold = torch.nn.Unfold(**fold_params) - - fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), - dilation=1, padding=0, - stride=(stride[0] // df, stride[1] // df)) - fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) - - weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) - normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap - weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) - - else: - raise NotImplementedError - - return fold, unfold, normalization, weighting - - @torch.no_grad() - def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, - cond_key=None, return_original_cond=False, bs=None): - x = super().get_input(batch, k) - if bs is not None: - x = x[:bs] - x = x.to(self.device) - encoder_posterior = self.encode_first_stage(x) - z = self.get_first_stage_encoding(encoder_posterior).detach() - - if self.model.conditioning_key is not None: - if cond_key is None: - cond_key = 
self.cond_stage_key - if cond_key != self.first_stage_key: - if cond_key in ['caption', 'coordinates_bbox']: - xc = batch[cond_key] - elif cond_key == 'class_label': - xc = batch - else: - xc = super().get_input(batch, cond_key).to(self.device) - else: - xc = x - if not self.cond_stage_trainable or force_c_encode: - if isinstance(xc, dict) or isinstance(xc, list): - # import pudb; pudb.set_trace() - c = self.get_learned_conditioning(xc) - else: - c = self.get_learned_conditioning(xc.to(self.device)) - else: - c = xc - if bs is not None: - c = c[:bs] - - if self.use_positional_encodings: - pos_x, pos_y = self.compute_latent_shifts(batch) - ckey = __conditioning_keys__[self.model.conditioning_key] - c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} - - else: - c = None - xc = None - if self.use_positional_encodings: - pos_x, pos_y = self.compute_latent_shifts(batch) - c = {'pos_x': pos_x, 'pos_y': pos_y} - out = [z, c] - if return_first_stage_outputs: - xrec = self.decode_first_stage(z) - out.extend([x, xrec]) - if return_original_cond: - out.append(xc) - return out - - @torch.no_grad() - def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): - if predict_cids: - if z.dim() == 4: - z = torch.argmax(z.exp(), dim=1).long() - z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) - z = rearrange(z, 'b h w c -> b c h w').contiguous() - - z = 1. / self.scale_factor * z - - if hasattr(self, "split_input_params"): - if self.split_input_params["patch_distributed_vq"]: - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. (64, 64) - uf = self.split_input_params["vqf"] - bs, nc, h, w = z.shape - if ks[0] > h or ks[1] > w: - ks = (min(ks[0], h), min(ks[1], w)) - print("reducing Kernel") - - if stride[0] > h or stride[1] > w: - stride = (min(stride[0], h), min(stride[1], w)) - print("reducing stride") - - fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) - - z = unfold(z) # (bn, nc * prod(**ks), L) - # 1. Reshape to img shape - z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) - - # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): - output_list = [self.first_stage_model.decode(z[:, :, :, :, i], - force_not_quantize=predict_cids or force_not_quantize) - for i in range(z.shape[-1])] - else: - - output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) - for i in range(z.shape[-1])] - - o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) - o = o * weighting - # Reverse 1. 
reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - decoded = fold(o) - decoded = decoded / normalization # norm is shape (1, 1, h, w) - return decoded - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) - else: - return self.first_stage_model.decode(z) - - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) - else: - return self.first_stage_model.decode(z) - - # same as above but without decorator - def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): - if predict_cids: - if z.dim() == 4: - z = torch.argmax(z.exp(), dim=1).long() - z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) - z = rearrange(z, 'b h w c -> b c h w').contiguous() - - z = 1. / self.scale_factor * z - - if hasattr(self, "split_input_params"): - if self.split_input_params["patch_distributed_vq"]: - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. (64, 64) - uf = self.split_input_params["vqf"] - bs, nc, h, w = z.shape - if ks[0] > h or ks[1] > w: - ks = (min(ks[0], h), min(ks[1], w)) - print("reducing Kernel") - - if stride[0] > h or stride[1] > w: - stride = (min(stride[0], h), min(stride[1], w)) - print("reducing stride") - - fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) - - z = unfold(z) # (bn, nc * prod(**ks), L) - # 1. Reshape to img shape - z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) - - # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): - output_list = [self.first_stage_model.decode(z[:, :, :, :, i], - force_not_quantize=predict_cids or force_not_quantize) - for i in range(z.shape[-1])] - else: - - output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) - for i in range(z.shape[-1])] - - o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) - o = o * weighting - # Reverse 1. reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - decoded = fold(o) - decoded = decoded / normalization # norm is shape (1, 1, h, w) - return decoded - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) - else: - return self.first_stage_model.decode(z) - - else: - if isinstance(self.first_stage_model, VQModelInterface): - return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) - else: - return self.first_stage_model.decode(z) - - @torch.no_grad() - def encode_first_stage(self, x): - if hasattr(self, "split_input_params"): - if self.split_input_params["patch_distributed_vq"]: - ks = self.split_input_params["ks"] # eg. (128, 128) - stride = self.split_input_params["stride"] # eg. 
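The split-input path above tiles a large feature map with torch.nn.Unfold, runs the first-stage model on each crop, and stitches the results back with torch.nn.Fold; because the crops overlap, the folded sum must be divided by a per-pixel overlap count, which is the role of the normalization tensor (with get_weighting refining it near borders). The round trip in isolation, with uniform weighting:

import torch

x = torch.randn(1, 3, 64, 64)
ks, stride = (32, 32), (16, 16)
unfold = torch.nn.Unfold(kernel_size=ks, stride=stride)
fold = torch.nn.Fold(output_size=x.shape[2:], kernel_size=ks, stride=stride)

patches = unfold(x)                          # (1, 3*32*32, L) overlapping crops
counts = fold(unfold(torch.ones_like(x)))    # per-pixel overlap count
recon = fold(patches) / counts               # stitch and normalize
assert torch.allclose(recon, x, atol=1e-6)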
(64, 64)
- df = self.split_input_params["vqf"]
- self.split_input_params['original_image_size'] = x.shape[-2:]
- bs, nc, h, w = x.shape
- if ks[0] > h or ks[1] > w:
- ks = (min(ks[0], h), min(ks[1], w))
- print("reducing Kernel")
-
- if stride[0] > h or stride[1] > w:
- stride = (min(stride[0], h), min(stride[1], w))
- print("reducing stride")
-
- fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
- z = unfold(x) # (bn, nc * prod(**ks), L)
- # Reshape to img shape
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
-
- output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
- for i in range(z.shape[-1])]
-
- o = torch.stack(output_list, axis=-1)
- o = o * weighting
-
- # Reverse reshape to img shape
- o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
- # stitch crops together
- decoded = fold(o)
- decoded = decoded / normalization
- return decoded
-
- else:
- return self.first_stage_model.encode(x)
- else:
- return self.first_stage_model.encode(x)
-
- def shared_step(self, batch, **kwargs):
- x, c = self.get_input(batch, self.first_stage_key)
- loss = self(x, c)
- return loss
-
- def forward(self, x, c, *args, **kwargs):
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
- if self.model.conditioning_key is not None:
- assert c is not None
- if self.cond_stage_trainable:
- c = self.get_learned_conditioning(c)
- if self.shorten_cond_schedule: # TODO: drop this option
- tc = self.cond_ids[t].to(self.device)
- c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
- return self.p_losses(x, c, t, *args, **kwargs)
-
- def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
- def rescale_bbox(bbox):
- x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
- y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
- w = min(bbox[2] / crop_coordinates[2], 1 - x0)
- h = min(bbox[3] / crop_coordinates[3], 1 - y0)
- return x0, y0, w, h
-
- return [rescale_bbox(b) for b in bboxes]
-
- def apply_model(self, x_noisy, t, cond, return_ids=False):
-
- if isinstance(cond, dict):
- # hybrid case, cond is expected to be a dict
- pass
- else:
- if not isinstance(cond, list):
- cond = [cond]
- key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
- cond = {key: cond}
-
- if hasattr(self, "split_input_params"):
- assert len(cond) == 1 # todo can only deal with one conditioning atm
- assert not return_ids
- ks = self.split_input_params["ks"] # eg. (128, 128)
- stride = self.split_input_params["stride"] # eg. (64, 64)
-
- h, w = x_noisy.shape[-2:]
-
- fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
-
- z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
- # Reshape to img shape
- z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
- z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
-
- if self.cond_stage_key in ["image", "LR_image", "segmentation",
- 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
- c_key = next(iter(cond.keys())) # get key
- c = next(iter(cond.values())) # get value
- assert (len(c) == 1) # todo extend to list with more than one elem
- c = c[0] # get element
-
- c = unfold(c)
- c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
-
- cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
-
- elif self.cond_stage_key == 'coordinates_bbox':
- assert 'original_image_size' in self.split_input_params, 'BoundingBoxRescaling is missing original_image_size'
-
- # assuming padding of unfold is always 0 and its dilation is always 1
- n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
- full_img_h, full_img_w = self.split_input_params['original_image_size']
- # as we are operating on latents, we need the factor from the original image size to the
- # spatial latent size to properly rescale the crops for regenerating the bbox annotations
- num_downs = self.first_stage_model.encoder.num_resolutions - 1
- rescale_latent = 2 ** (num_downs)
-
- # get the top-left positions of the patches, conforming to the bbox tokenizer; therefore we
- # need to rescale the tl patch coordinates to be in between (0, 1)
- tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
- rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
- for patch_nr in range(z.shape[-1])]
-
- # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, w, h)
- patch_limits = [(x_tl, y_tl,
- rescale_latent * ks[0] / full_img_w,
- rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
- # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
-
- # tokenize crop coordinates for the bounding boxes of the respective patches
- patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
- for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
- print(patch_limits_tknzd[0].shape)
- # cut tknzd crop position from conditioning
- assert isinstance(cond, dict), 'cond must be dict to be fed into model'
- cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
- print(cut_cond.shape)
-
- adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
- adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
- print(adapted_cond.shape)
- adapted_cond = self.get_learned_conditioning(adapted_cond)
- print(adapted_cond.shape)
- adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
- print(adapted_cond.shape)
-
- cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
-
- else:
- cond_list = [cond for i in range(z.shape[-1])] # todo: make this more efficient
-
- # apply model by loop over crops
- output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
- assert not isinstance(output_list[0],
- tuple) # todo: can't deal with multiple model outputs; check that this never happens
-
- o = torch.stack(output_list,
axis=-1) - o = o * weighting - # Reverse reshape to img shape - o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) - # stitch crops together - x_recon = fold(o) / normalization - - else: - x_recon = self.model(x_noisy, t, **cond) - - if isinstance(x_recon, tuple) and not return_ids: - return x_recon[0] - else: - return x_recon - - def _predict_eps_from_xstart(self, x_t, t, pred_xstart): - return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) - - def _prior_bpd(self, x_start): - """ - Get the prior KL term for the variational lower-bound, measured in - bits-per-dim. - This term can't be optimized, as it only depends on the encoder. - :param x_start: the [N x C x ...] tensor of inputs. - :return: a batch of [N] KL values (in bits), one per batch element. - """ - batch_size = x_start.shape[0] - t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) - qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) - kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) - return mean_flat(kl_prior) / np.log(2.0) - - def p_losses(self, x_start, cond, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) - model_output = self.apply_model(x_noisy, t, cond) - - loss_dict = {} - prefix = 'train' if self.training else 'val' - - if self.parameterization == "x0": - target = x_start - elif self.parameterization == "eps": - target = noise - else: - raise NotImplementedError() - - loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) - loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) - - logvar_t = self.logvar[t].to(self.device) - loss = loss_simple / torch.exp(logvar_t) + logvar_t - # loss = loss_simple / torch.exp(self.logvar) + self.logvar - if self.learn_logvar: - loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) - loss_dict.update({'logvar': self.logvar.data.mean()}) - - loss = self.l_simple_weight * loss.mean() - - loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) - loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() - loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) - loss += (self.original_elbo_weight * loss_vlb) - loss_dict.update({f'{prefix}/loss': loss}) - - return loss, loss_dict - - def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, - return_x0=False, score_corrector=None, corrector_kwargs=None): - t_in = t - model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) - - if score_corrector is not None: - assert self.parameterization == "eps" - model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) - - if return_codebook_ids: - model_out, logits = model_out - - if self.parameterization == "eps": - x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) - elif self.parameterization == "x0": - x_recon = model_out - else: - raise NotImplementedError() - - if clip_denoised: - x_recon.clamp_(-1., 1.) 
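# A minimal standalone sketch of the eps <-> x0 conversion used via
# predict_start_from_noise / _predict_eps_from_xstart above (assumes
# `alphas_cumprod` is the usual DDPM cumulative product of (1 - beta);
# the helper name is illustrative, not part of the original file):
#
#     def x0_from_eps(x_t, eps, t, alphas_cumprod):
#         # x_0 = sqrt(1 / abar_t) * x_t - sqrt(1 / abar_t - 1) * eps
#         abar = alphas_cumprod[t].view(-1, *([1] * (x_t.dim() - 1)))
#         return (1. / abar).sqrt() * x_t - (1. / abar - 1.).sqrt() * eps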
- if quantize_denoised: - x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) - model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) - if return_codebook_ids: - return model_mean, posterior_variance, posterior_log_variance, logits - elif return_x0: - return model_mean, posterior_variance, posterior_log_variance, x_recon - else: - return model_mean, posterior_variance, posterior_log_variance - - @torch.no_grad() - def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, - return_codebook_ids=False, quantize_denoised=False, return_x0=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): - b, *_, device = *x.shape, x.device - outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, - return_codebook_ids=return_codebook_ids, - quantize_denoised=quantize_denoised, - return_x0=return_x0, - score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) - if return_codebook_ids: - raise DeprecationWarning("Support dropped.") - model_mean, _, model_log_variance, logits = outputs - elif return_x0: - model_mean, _, model_log_variance, x0 = outputs - else: - model_mean, _, model_log_variance = outputs - - noise = noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - # no noise when t == 0 - nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) - - if return_codebook_ids: - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) - if return_x0: - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 - else: - return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise - - @torch.no_grad() - def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, - img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., - score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, - log_every_t=None): - if not log_every_t: - log_every_t = self.log_every_t - timesteps = self.num_timesteps - if batch_size is not None: - b = batch_size if batch_size is not None else shape[0] - shape = [batch_size] + list(shape) - else: - b = batch_size = shape[0] - if x_T is None: - img = torch.randn(shape, device=self.device) - else: - img = x_T - intermediates = [] - if cond is not None: - if isinstance(cond, dict): - cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} - else: - cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] - - if start_T is not None: - timesteps = min(timesteps, start_T) - iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', - total=timesteps) if verbose else reversed( - range(0, timesteps)) - if type(temperature) == float: - temperature = [temperature] * timesteps - - for i in iterator: - ts = torch.full((b,), i, device=self.device, dtype=torch.long) - if self.shorten_cond_schedule: - assert self.model.conditioning_key != 'hybrid' - tc = self.cond_ids[ts].to(cond.device) - cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) - - img, x0_partial = self.p_sample(img, cond, ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised, return_x0=True, - temperature=temperature[i], noise_dropout=noise_dropout, - 
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) - if mask is not None: - assert x0 is not None - img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1. - mask) * img - - if i % log_every_t == 0 or i == timesteps - 1: - intermediates.append(x0_partial) - if callback: callback(i) - if img_callback: img_callback(img, i) - return img, intermediates - - @torch.no_grad() - def p_sample_loop(self, cond, shape, return_intermediates=False, - x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, start_T=None, - log_every_t=None): - - if not log_every_t: - log_every_t = self.log_every_t - device = self.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - intermediates = [img] - if timesteps is None: - timesteps = self.num_timesteps - - if start_T is not None: - timesteps = min(timesteps, start_T) - iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( - range(0, timesteps)) - - if mask is not None: - assert x0 is not None - assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match - - for i in iterator: - ts = torch.full((b,), i, device=device, dtype=torch.long) - if self.shorten_cond_schedule: - assert self.model.conditioning_key != 'hybrid' - tc = self.cond_ids[ts].to(cond.device) - cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) - - img = self.p_sample(img, cond, ts, - clip_denoised=self.clip_denoised, - quantize_denoised=quantize_denoised) - if mask is not None: - img_orig = self.q_sample(x0, ts) - img = img_orig * mask + (1. - mask) * img - - if i % log_every_t == 0 or i == timesteps - 1: - intermediates.append(img) - if callback: callback(i) - if img_callback: img_callback(img, i) - - if return_intermediates: - return img, intermediates - return img - - @torch.no_grad() - def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, - verbose=True, timesteps=None, quantize_denoised=False, - mask=None, x0=None, shape=None,**kwargs): - if shape is None: - shape = (batch_size, self.channels, self.image_size, self.image_size) - if cond is not None: - if isinstance(cond, dict): - cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} - else: - cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] - return self.p_sample_loop(cond, - shape, - return_intermediates=return_intermediates, x_T=x_T, - verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, - mask=mask, x0=x0) - - @torch.no_grad() - def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): - - if ddim: - ddim_sampler = DDIMSampler(self) - shape = (self.channels, self.image_size, self.image_size) - samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, - shape,cond,verbose=False,**kwargs) - - else: - samples, intermediates = self.sample(cond=cond, batch_size=batch_size, - return_intermediates=True,**kwargs) - - return samples, intermediates - - - @torch.no_grad() - def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, - quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, - plot_diffusion_rows=True, **kwargs): - - use_ddim = ddim_steps is not None - - log = dict() - z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, - 
return_first_stage_outputs=True, - force_c_encode=True, - return_original_cond=True, - bs=N) - N = min(x.shape[0], N) - n_row = min(x.shape[0], n_row) - log["inputs"] = x - log["reconstruction"] = xrec - if self.model.conditioning_key is not None: - if hasattr(self.cond_stage_model, "decode"): - xc = self.cond_stage_model.decode(c) - log["conditioning"] = xc - elif self.cond_stage_key in ["caption"]: - xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) - log["conditioning"] = xc - elif self.cond_stage_key == 'class_label': - xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) - log['conditioning'] = xc - elif isimage(xc): - log["conditioning"] = xc - if ismap(xc): - log["original_conditioning"] = self.to_rgb(xc) - - if plot_diffusion_rows: - # get diffusion row - diffusion_row = list() - z_start = z[:n_row] - for t in range(self.num_timesteps): - if t % self.log_every_t == 0 or t == self.num_timesteps - 1: - t = repeat(torch.tensor([t]), '1 -> b', b=n_row) - t = t.to(self.device).long() - noise = torch.randn_like(z_start) - z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) - diffusion_row.append(self.decode_first_stage(z_noisy)) - - diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W - diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') - diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') - diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) - log["diffusion_row"] = diffusion_grid - - if sample: - # get denoise row - with self.ema_scope("Plotting"): - samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, - ddim_steps=ddim_steps,eta=ddim_eta) - # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) - x_samples = self.decode_first_stage(samples) - log["samples"] = x_samples - if plot_denoise_rows: - denoise_grid = self._get_denoise_row_from_list(z_denoise_row) - log["denoise_row"] = denoise_grid - - if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( - self.first_stage_model, IdentityFirstStage): - # also display when quantizing x0 while sampling - with self.ema_scope("Plotting Quantized Denoised"): - samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, - ddim_steps=ddim_steps,eta=ddim_eta, - quantize_denoised=True) - # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, - # quantize_denoised=True) - x_samples = self.decode_first_stage(samples.to(self.device)) - log["samples_x0_quantized"] = x_samples - - if inpaint: - # make a simple center square - b, h, w = z.shape[0], z.shape[2], z.shape[3] - mask = torch.ones(N, h, w).to(self.device) - # zeros will be filled in - mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. - mask = mask[:, None, ...] 
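# A standalone sketch of the inpainting mask built just above (illustrative
# only; ones mark pixels to keep, zeros mark the center square to be
# regenerated):
#
#     def center_square_mask(n, h, w, device):
#         mask = torch.ones(n, 1, h, w, device=device)
#         mask[:, :, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
#         return mask
#
# During sampling, p_sample_loop then blends known and generated content at
# every step: img = self.q_sample(x0, ts) * mask + (1. - mask) * img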
- with self.ema_scope("Plotting Inpaint"): - - samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, - ddim_steps=ddim_steps, x0=z[:N], mask=mask) - x_samples = self.decode_first_stage(samples.to(self.device)) - log["samples_inpainting"] = x_samples - log["mask"] = mask - - # outpaint - with self.ema_scope("Plotting Outpaint"): - samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, - ddim_steps=ddim_steps, x0=z[:N], mask=mask) - x_samples = self.decode_first_stage(samples.to(self.device)) - log["samples_outpainting"] = x_samples - - if plot_progressive_rows: - with self.ema_scope("Plotting Progressives"): - img, progressives = self.progressive_denoising(c, - shape=(self.channels, self.image_size, self.image_size), - batch_size=N) - prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") - log["progressive_row"] = prog_row - - if return_keys: - if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: - return log - else: - return {key: log[key] for key in return_keys} - return log - - def configure_optimizers(self): - lr = self.learning_rate - params = list(self.model.parameters()) - if self.cond_stage_trainable: - print(f"{self.__class__.__name__}: Also optimizing conditioner params!") - params = params + list(self.cond_stage_model.parameters()) - if self.learn_logvar: - print('Diffusion model optimizing logvar') - params.append(self.logvar) - opt = torch.optim.AdamW(params, lr=lr) - if self.use_scheduler: - assert 'target' in self.scheduler_config - scheduler = instantiate_from_config(self.scheduler_config) - - print("Setting up LambdaLR scheduler...") - scheduler = [ - { - 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), - 'interval': 'step', - 'frequency': 1 - }] - return [opt], scheduler - return opt - - @torch.no_grad() - def to_rgb(self, x): - x = x.float() - if not hasattr(self, "colorize"): - self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) - x = nn.functional.conv2d(x, weight=self.colorize) - x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. 
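# (to_rgb projects an arbitrary-channel feature map down to three channels
# with a fixed random 1x1 convolution, then min-max rescales into [-1, 1].)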
- return x
-
-
-class DiffusionWrapper(pl.LightningModule):
- def __init__(self, diff_model_config, conditioning_key):
- super().__init__()
- self.diffusion_model = instantiate_from_config(diff_model_config)
- self.conditioning_key = conditioning_key
- assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
-
- def forward(self, x, t, c_concat: list = None, c_crossattn: list = None):
- if self.conditioning_key is None:
- out = self.diffusion_model(x, t)
- elif self.conditioning_key == 'concat':
- xc = torch.cat([x] + c_concat, dim=1)
- out = self.diffusion_model(xc, t)
- elif self.conditioning_key == 'crossattn':
- cc = torch.cat(c_crossattn, 1)
- out = self.diffusion_model(x, t, context=cc)
- elif self.conditioning_key == 'hybrid':
- xc = torch.cat([x] + c_concat, dim=1)
- cc = torch.cat(c_crossattn, 1)
- out = self.diffusion_model(xc, t, context=cc)
- elif self.conditioning_key == 'adm':
- cc = c_crossattn[0]
- out = self.diffusion_model(x, t, y=cc)
- else:
- raise NotImplementedError()
-
- return out
-
-
-class Layout2ImgDiffusion(LatentDiffusion):
- # TODO: move all layout-specific hacks to this class
- def __init__(self, cond_stage_key, *args, **kwargs):
- assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
- super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
-
- def log_images(self, batch, N=8, *args, **kwargs):
- logs = super().log_images(batch=batch, N=N, *args, **kwargs)
-
- key = 'train' if self.training else 'validation'
- dset = self.trainer.datamodule.datasets[key]
- mapper = dset.conditional_builders[self.cond_stage_key]
-
- bbox_imgs = []
- map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
- for tknzd_bbox in batch[self.cond_stage_key][:N]:
- bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
- bbox_imgs.append(bboximg)
-
- cond_img = torch.stack(bbox_imgs, dim=0)
- logs['bbox_image'] = cond_img
- return logs
diff --git a/ldm/models/diffusion/dpm_solver/__init__.py b/ldm/models/diffusion/dpm_solver/__init__.py
deleted file mode 100644
index 7427f38c..00000000
--- a/ldm/models/diffusion/dpm_solver/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .sampler import DPMSolverSampler
\ No newline at end of file
diff --git a/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ldm/models/diffusion/dpm_solver/dpm_solver.py
deleted file mode 100644
index bdb64e0c..00000000
--- a/ldm/models/diffusion/dpm_solver/dpm_solver.py
+++ /dev/null
@@ -1,1184 +0,0 @@
-import torch
-import torch.nn.functional as F
-import math
-
-
-class NoiseScheduleVP:
- def __init__(
- self,
- schedule='discrete',
- betas=None,
- alphas_cumprod=None,
- continuous_beta_0=0.1,
- continuous_beta_1=20.,
- ):
- """Create a wrapper class for the forward SDE (VP type).
-
- ***
- Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
- We recommend using schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
- ***
-
- The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
- We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
- Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t.
For t in [0, T], we have: - - log_alpha_t = self.marginal_log_mean_coeff(t) - sigma_t = self.marginal_std(t) - lambda_t = self.marginal_lambda(t) - - Moreover, as lambda(t) is an invertible function, we also support its inverse function: - - t = self.inverse_lambda(lambda_t) - - =============================================================== - - We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). - - 1. For discrete-time DPMs: - - For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: - t_i = (i + 1) / N - e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. - We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. - - Args: - betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) - alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) - - Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. - - **Important**: Please pay special attention for the args for `alphas_cumprod`: - The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that - q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). - Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have - alpha_{t_n} = \sqrt{\hat{alpha_n}}, - and - log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). - - - 2. For continuous-time DPMs: - - We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise - schedule are the default settings in DDPM and improved-DDPM: - - Args: - beta_min: A `float` number. The smallest beta for the linear schedule. - beta_max: A `float` number. The largest beta for the linear schedule. - cosine_s: A `float` number. The hyperparameter in the cosine schedule. - cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. - T: A `float` number. The ending time of the forward process. - - =============================================================== - - Args: - schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, - 'linear' or 'cosine' for continuous-time DPMs. - Returns: - A wrapper object of the forward SDE (VP type). - - =============================================================== - - Example: - - # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): - >>> ns = NoiseScheduleVP('discrete', betas=betas) - - # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): - >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) - - # For continuous-time DPMs (VPSDE), linear schedule: - >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) - - """ - - if schedule not in ['discrete', 'linear', 'cosine']: - raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)) - - self.schedule = schedule - if schedule == 'discrete': - if betas is not None: - log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) - else: - assert alphas_cumprod is not None - log_alphas = 0.5 * torch.log(alphas_cumprod) - self.total_N = len(log_alphas) - self.T = 1. 
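# A sketch of this discrete-time wrapping on a standard DDPM linear beta
# schedule (the numbers are illustrative, not prescriptive):
#
#     betas = torch.linspace(1e-4, 2e-2, 1000)
#     alphas_cumprod = torch.cumprod(1. - betas, dim=0)
#     log_alphas = 0.5 * torch.log(alphas_cumprod)   # log alpha_{t_n} = 0.5 * log(abar_n)
#     t_array = torch.linspace(0., 1., 1001)[1:]     # t_i = (i + 1) / N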
- self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) - self.log_alpha_array = log_alphas.reshape((1, -1,)) - else: - self.total_N = 1000 - self.beta_0 = continuous_beta_0 - self.beta_1 = continuous_beta_1 - self.cosine_s = 0.008 - self.cosine_beta_max = 999. - self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s - self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) - self.schedule = schedule - if schedule == 'cosine': - # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. - # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. - self.T = 0.9946 - else: - self.T = 1. - - def marginal_log_mean_coeff(self, t): - """ - Compute log(alpha_t) of a given continuous-time label t in [0, T]. - """ - if self.schedule == 'discrete': - return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) - elif self.schedule == 'linear': - return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 - elif self.schedule == 'cosine': - log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) - log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 - return log_alpha_t - - def marginal_alpha(self, t): - """ - Compute alpha_t of a given continuous-time label t in [0, T]. - """ - return torch.exp(self.marginal_log_mean_coeff(t)) - - def marginal_std(self, t): - """ - Compute sigma_t of a given continuous-time label t in [0, T]. - """ - return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) - - def marginal_lambda(self, t): - """ - Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. - """ - log_mean_coeff = self.marginal_log_mean_coeff(t) - log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) - return log_mean_coeff - log_std - - def inverse_lambda(self, lamb): - """ - Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. - """ - if self.schedule == 'linear': - tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - Delta = self.beta_0**2 + tmp - return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) - elif self.schedule == 'discrete': - log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) - t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) - return t.reshape((-1,)) - else: - log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) - t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s - t = t_fn(log_alpha) - return t - - -def model_wrapper( - model, - noise_schedule, - model_type="noise", - model_kwargs={}, - guidance_type="uncond", - condition=None, - unconditional_condition=None, - guidance_scale=1., - classifier_fn=None, - classifier_kwargs={}, -): - """Create a wrapper function for the noise prediction model. - - DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to - firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. 
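For a discrete-time model with N = 1000 steps, this wrapping reduces to the
time map implemented by `get_model_input_time` below (a sketch):
``
t_input = (t_continuous - 1. / N) * 1000. # maps t in (0, 1] onto [0, 999]
``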
-
- We support four types of the diffusion model by setting `model_type`:
-
- 1. "noise": noise prediction model. (Trained by predicting noise).
-
- 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
-
- 3. "v": velocity prediction model. (Trained by predicting the velocity).
- The "v" prediction is derived in detail in Appendix D of [1], and is used in Imagen-Video [2].
-
- [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
- arXiv preprint arXiv:2202.00512 (2022).
- [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
- arXiv preprint arXiv:2210.02303 (2022).
-
- 4. "score": marginal score function. (Trained by denoising score matching).
- Note that the score function and the noise prediction model follow a simple relationship:
- ```
- noise(x_t, t) = -sigma_t * score(x_t, t)
- ```
-
- We support three types of guided sampling by DPMs by setting `guidance_type`:
- 1. "uncond": unconditional sampling by DPMs.
- The input `model` has the following format:
- ``
- model(x, t_input, **model_kwargs) -> noise | x_start | v | score
- ``
-
- 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
- The input `model` has the following format:
- ``
- model(x, t_input, **model_kwargs) -> noise | x_start | v | score
- ``
-
- The input `classifier_fn` has the following format:
- ``
- classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
- ``
-
- [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
- in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
-
- 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
- The input `model` has the following format:
- ``
- model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
- ``
- And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
-
- [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
- arXiv preprint arXiv:2207.12598 (2022).
-
-
- The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
- or continuous-time labels (i.e. epsilon to T).
-
- We wrap the model function to accept only `x` and `t_continuous` as inputs, and output the predicted noise:
- ``
- def model_fn(x, t_continuous) -> noise:
- t_input = get_model_input_time(t_continuous)
- return noise_pred(model, x, t_input, **model_kwargs)
- ``
- where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
-
- ===============================================================
-
- Args:
- model: A diffusion model with the corresponding format described above.
- noise_schedule: A noise schedule object, such as NoiseScheduleVP.
- model_type: A `str`. The parameterization type of the diffusion model.
- "noise" or "x_start" or "v" or "score".
- model_kwargs: A `dict`. A dict for the other inputs of the model function.
- guidance_type: A `str`. The type of the guidance for sampling.
- "uncond" or "classifier" or "classifier-free".
- condition: A pytorch tensor. The condition for the guided sampling.
- Only used for "classifier" or "classifier-free" guidance type.
- unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
- Only used for "classifier-free" guidance type.
- guidance_scale: A `float`. The scale for the guided sampling.
- classifier_fn: A classifier function. Only used for the classifier guidance.
- classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
- Returns:
- A noise prediction model that accepts the noised data and the continuous time as the inputs.
- """
-
- def get_model_input_time(t_continuous):
- """
- Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
- For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
- For continuous-time DPMs, we just use `t_continuous`.
- """
- if noise_schedule.schedule == 'discrete':
- return (t_continuous - 1. / noise_schedule.total_N) * 1000.
- else:
- return t_continuous
-
- def noise_pred_fn(x, t_continuous, cond=None):
- if t_continuous.reshape((-1,)).shape[0] == 1:
- t_continuous = t_continuous.expand((x.shape[0]))
- t_input = get_model_input_time(t_continuous)
- if cond is None:
- output = model(x, t_input, **model_kwargs)
- else:
- output = model(x, t_input, cond, **model_kwargs)
- if model_type == "noise":
- return output
- elif model_type == "x_start":
- alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
- dims = x.dim()
- return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
- elif model_type == "v":
- alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
- dims = x.dim()
- return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
- elif model_type == "score":
- sigma_t = noise_schedule.marginal_std(t_continuous)
- dims = x.dim()
- return -expand_dims(sigma_t, dims) * output
-
- def cond_grad_fn(x, t_input):
- """
- Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
- """
- with torch.enable_grad():
- x_in = x.detach().requires_grad_(True)
- log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
- return torch.autograd.grad(log_prob.sum(), x_in)[0]
-
- def model_fn(x, t_continuous):
- """
- The noise prediction model function that is used for DPM-Solver.
- """
- if t_continuous.reshape((-1,)).shape[0] == 1:
- t_continuous = t_continuous.expand((x.shape[0]))
- if guidance_type == "uncond":
- return noise_pred_fn(x, t_continuous)
- elif guidance_type == "classifier":
- assert classifier_fn is not None
- t_input = get_model_input_time(t_continuous)
- cond_grad = cond_grad_fn(x, t_input)
- sigma_t = noise_schedule.marginal_std(t_continuous)
- noise = noise_pred_fn(x, t_continuous)
- return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
- elif guidance_type == "classifier-free":
- if guidance_scale == 1. or unconditional_condition is None:
- return noise_pred_fn(x, t_continuous, cond=condition)
- else:
- x_in = torch.cat([x] * 2)
- t_in = torch.cat([t_continuous] * 2)
- c_in = torch.cat([unconditional_condition, condition])
- noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
- return noise_uncond + guidance_scale * (noise - noise_uncond)
-
- assert model_type in ["noise", "x_start", "v"]
- assert guidance_type in ["uncond", "classifier", "classifier-free"]
- return model_fn
-
-
-class DPM_Solver:
- def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
- """Construct a DPM-Solver.
-
- We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
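The classifier-free branch of `model_fn` above reduces to the familiar
guidance combination (a sketch; eps_u / eps_c denote the unconditional and
conditional noise predictions):
``
eps = eps_u + guidance_scale * (eps_c - eps_u)
``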
- If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). - If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). - In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True. - The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales. - - Args: - model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): - `` - def model_fn(x, t_continuous): - return noise - `` - noise_schedule: A noise schedule object, such as NoiseScheduleVP. - predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. - thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. - max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding. - - [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. - """ - self.model = model_fn - self.noise_schedule = noise_schedule - self.predict_x0 = predict_x0 - self.thresholding = thresholding - self.max_val = max_val - - def noise_prediction_fn(self, x, t): - """ - Return the noise prediction model. - """ - return self.model(x, t) - - def data_prediction_fn(self, x, t): - """ - Return the data prediction model (with thresholding). - """ - noise = self.noise_prediction_fn(x, t) - dims = x.dim() - alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) - x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) - if self.thresholding: - p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. - s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) - s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) - x0 = torch.clamp(x0, -s, s) / s - return x0 - - def model_fn(self, x, t): - """ - Convert the model to the noise prediction model or the data prediction model. - """ - if self.predict_x0: - return self.data_prediction_fn(x, t) - else: - return self.noise_prediction_fn(x, t) - - def get_time_steps(self, skip_type, t_T, t_0, N, device): - """Compute the intermediate time steps for sampling. - - Args: - skip_type: A `str`. The type for the spacing of the time steps. We support three types: - - 'logSNR': uniform logSNR for the time steps. - - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) - - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - N: A `int`. The total number of the spacing of the time steps. - device: A torch device. - Returns: - A pytorch tensor of the time steps, with the shape (N + 1,). 
- """ - if skip_type == 'logSNR': - lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) - lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) - logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) - return self.noise_schedule.inverse_lambda(logSNR_steps) - elif skip_type == 'time_uniform': - return torch.linspace(t_T, t_0, N + 1).to(device) - elif skip_type == 'time_quadratic': - t_order = 2 - t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) - return t - else: - raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) - - def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): - """ - Get the order of each step for sampling by the singlestep DPM-Solver. - - We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". - Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: - - If order == 1: - We take `steps` of DPM-Solver-1 (i.e. DDIM). - - If order == 2: - - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. - - If steps % 2 == 0, we use K steps of DPM-Solver-2. - - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. - - If order == 3: - - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. - - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. - - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. - - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. - - ============================================ - Args: - order: A `int`. The max order for the solver (2 or 3). - steps: A `int`. The total number of function evaluations (NFE). - skip_type: A `str`. The type for the spacing of the time steps. We support three types: - - 'logSNR': uniform logSNR for the time steps. - - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) - - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - device: A torch device. - Returns: - orders: A list of the solver order of each step. - """ - if order == 3: - K = steps // 3 + 1 - if steps % 3 == 0: - orders = [3,] * (K - 2) + [2, 1] - elif steps % 3 == 1: - orders = [3,] * (K - 1) + [1] - else: - orders = [3,] * (K - 1) + [2] - elif order == 2: - if steps % 2 == 0: - K = steps // 2 - orders = [2,] * K - else: - K = steps // 2 + 1 - orders = [2,] * (K - 1) + [1] - elif order == 1: - K = 1 - orders = [1,] * steps - else: - raise ValueError("'order' must be '1' or '2' or '3'.") - if skip_type == 'logSNR': - # To reproduce the results in DPM-Solver paper - timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) - else: - timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders)).to(device)] - return timesteps_outer, orders - - def denoise_to_zero_fn(self, x, s): - """ - Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. 
- """ - return self.data_prediction_fn(x, s) - - def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): - """ - DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - return_intermediate: A `bool`. If true, also return the model value at time `s`. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - ns = self.noise_schedule - dims = x.dim() - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - if self.predict_x0: - phi_1 = torch.expm1(-h) - if model_s is None: - model_s = self.model_fn(x, s) - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - ) - if return_intermediate: - return x_t, {'model_s': model_s} - else: - return x_t - else: - phi_1 = torch.expm1(h) - if model_s is None: - model_s = self.model_fn(x, s) - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - ) - if return_intermediate: - return x_t, {'model_s': model_s} - else: - return x_t - - def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpm_solver'): - """ - Singlestep solver DPM-Solver-2 from time `s` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - r1: A `float`. The hyperparameter of the second-order solver. - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. 
- """ - if solver_type not in ['dpm_solver', 'taylor']: - raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) - if r1 is None: - r1 = 0.5 - ns = self.noise_schedule - dims = x.dim() - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - lambda_s1 = lambda_s + r1 * h - s1 = ns.inverse_lambda(lambda_s1) - log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) - alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) - - if self.predict_x0: - phi_11 = torch.expm1(-r1 * h) - phi_1 = torch.expm1(-h) - - if model_s is None: - model_s = self.model_fn(x, s) - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (model_s1 - model_s) - ) - else: - phi_11 = torch.expm1(r1 * h) - phi_1 = torch.expm1(h) - - if model_s is None: - model_s = self.model_fn(x, s) - x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) - ) - if return_intermediate: - return x_t, {'model_s': model_s, 'model_s1': model_s1} - else: - return x_t - - def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpm_solver'): - """ - Singlestep solver DPM-Solver-3 from time `s` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - r1: A `float`. The hyperparameter of the third-order solver. - r2: A `float`. The hyperparameter of the third-order solver. - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). - If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. - return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. 
- Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if solver_type not in ['dpm_solver', 'taylor']: - raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) - if r1 is None: - r1 = 1. / 3. - if r2 is None: - r2 = 2. / 3. - ns = self.noise_schedule - dims = x.dim() - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - lambda_s1 = lambda_s + r1 * h - lambda_s2 = lambda_s + r2 * h - s1 = ns.inverse_lambda(lambda_s1) - s2 = ns.inverse_lambda(lambda_s2) - log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t) - alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) - - if self.predict_x0: - phi_11 = torch.expm1(-r1 * h) - phi_12 = torch.expm1(-r2 * h) - phi_1 = torch.expm1(-h) - phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. - phi_2 = phi_1 / h + 1. - phi_3 = phi_2 / h - 0.5 - - if model_s is None: - model_s = self.model_fn(x, s) - if model_s1 is None: - x_s1 = ( - expand_dims(sigma_s1 / sigma_s, dims) * x - - expand_dims(alpha_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - x_s2 = ( - expand_dims(sigma_s2 / sigma_s, dims) * x - - expand_dims(alpha_s2 * phi_12, dims) * model_s - + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) - ) - model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) - ) - elif solver_type == 'taylor': - D1_0 = (1. / r1) * (model_s1 - model_s) - D1_1 = (1. / r2) * (model_s2 - model_s) - D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2. * (D1_1 - D1_0) / (r2 - r1) - x_t = ( - expand_dims(sigma_t / sigma_s, dims) * x - - expand_dims(alpha_t * phi_1, dims) * model_s - + expand_dims(alpha_t * phi_2, dims) * D1 - - expand_dims(alpha_t * phi_3, dims) * D2 - ) - else: - phi_11 = torch.expm1(r1 * h) - phi_12 = torch.expm1(r2 * h) - phi_1 = torch.expm1(h) - phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. - phi_2 = phi_1 / h - 1. - phi_3 = phi_2 / h - 0.5 - - if model_s is None: - model_s = self.model_fn(x, s) - if model_s1 is None: - x_s1 = ( - expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x - - expand_dims(sigma_s1 * phi_11, dims) * model_s - ) - model_s1 = self.model_fn(x_s1, s1) - x_s2 = ( - expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x - - expand_dims(sigma_s2 * phi_12, dims) * model_s - - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) - ) - model_s2 = self.model_fn(x_s2, s2) - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) - ) - elif solver_type == 'taylor': - D1_0 = (1. / r1) * (model_s1 - model_s) - D1_1 = (1. / r2) * (model_s2 - model_s) - D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2. 
* (D1_1 - D1_0) / (r2 - r1) - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x - - expand_dims(sigma_t * phi_1, dims) * model_s - - expand_dims(sigma_t * phi_2, dims) * D1 - - expand_dims(sigma_t * phi_3, dims) * D2 - ) - - if return_intermediate: - return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} - else: - return x_t - - def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): - """ - Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if solver_type not in ['dpm_solver', 'taylor']: - raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) - ns = self.noise_schedule - dims = x.dim() - model_prev_1, model_prev_0 = model_prev_list - t_prev_1, t_prev_0 = t_prev_list - lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) - log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) - sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - h_0 = lambda_prev_0 - lambda_prev_1 - h = lambda_t - lambda_prev_0 - r0 = h_0 / h - D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) - if self.predict_x0: - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 - ) - else: - if solver_type == 'dpm_solver': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 - ) - elif solver_type == 'taylor': - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 - ) - return x_t - - def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): - """ - Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. 
We recommend to use 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - ns = self.noise_schedule - dims = x.dim() - model_prev_2, model_prev_1, model_prev_0 = model_prev_list - t_prev_2, t_prev_1, t_prev_0 = t_prev_list - lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) - log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) - sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - h_1 = lambda_prev_1 - lambda_prev_2 - h_0 = lambda_prev_0 - lambda_prev_1 - h = lambda_t - lambda_prev_0 - r0, r1 = h_0 / h, h_1 / h - D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) - D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) - D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) - D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) - if self.predict_x0: - x_t = ( - expand_dims(sigma_t / sigma_prev_0, dims) * x - - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 - + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 - - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h**2 - 0.5), dims) * D2 - ) - else: - x_t = ( - expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x - - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 - - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 - - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h**2 - 0.5), dims) * D2 - ) - return x_t - - def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, r2=None): - """ - Singlestep DPM-Solver with the order `order` from time `s` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (x.shape[0],). - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. - return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpm_solver' type. - r1: A `float`. The hyperparameter of the second-order or third-order solver. - r2: A `float`. The hyperparameter of the third-order solver. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if order == 1: - return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) - elif order == 2: - return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1) - elif order == 3: - return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2) - else: - raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - - def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): - """ - Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. 
- t_prev_list: A list of pytorch tensors. The previous times, each with the shape (x.shape[0],) - t: A pytorch tensor. The ending time, with the shape (x.shape[0],). - order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend the 'dpm_solver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if order == 1: - return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) - elif order == 2: - return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) - elif order == 3: - return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) - else: - raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - - def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpm_solver'): - """ - The adaptive step size solver based on singlestep DPM-Solver. - - Args: - x: A pytorch tensor. The initial value at time `t_T`. - order: An `int`. The (higher) order of the solver. We only support order == 2 or 3. - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - h_init: A `float`. The initial step size (for logSNR). - atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1]. - rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. - theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1]. - t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the - current time and `t_0` is less than `t_err`. The default setting is 1e-5. - solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend the 'dpm_solver' type. - Returns: - x_0: A pytorch tensor. The approximated solution at time `t_0`. - - [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. - """ - ns = self.noise_schedule - s = t_T * torch.ones((x.shape[0],)).to(x) - lambda_s = ns.marginal_lambda(s) - lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) - h = h_init * torch.ones_like(s).to(x) - x_prev = x - nfe = 0 - if order == 2: - r1 = 0.5 - lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) - elif order == 3: - r1, r2 = 1. / 3., 2. / 3.
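# The adaptive loop below uses an embedded pair, as in classical adaptive ODE
# solvers: the lower-order update returns its model evaluations via
# `return_intermediate=True`, and those are reused by the higher-order update,
# so each attempted step costs `order` NFE. The difference between the two
# solutions, scaled by the atol/rtol tolerance, drives the step-size control.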
- lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type) - higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) - else: - raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) - while torch.abs((s - t_0)).mean() > t_err: - t = ns.inverse_lambda(lambda_s + h) - x_lower, lower_noise_kwargs = lower_update(x, s, t) - x_higher = higher_update(x, s, t, **lower_noise_kwargs) - delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) - norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) - E = norm_fn((x_higher - x_lower) / delta).max() - if torch.all(E <= 1.): - x = x_higher - s = t - x_prev = x_lower - lambda_s = ns.marginal_lambda(s) - h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) - nfe += order - print('adaptive solver nfe', nfe) - return x - - def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', - method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', - atol=0.0078, rtol=0.05, - ): - """ - Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. - - ===================================================== - - We support the following algorithms for both noise prediction model and data prediction model: - - 'singlestep': - Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. - We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). - The total number of function evaluations (NFE) == `steps`. - Given a fixed NFE == `steps`, the sampling procedure is: - - If `order` == 1: - - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). - - If `order` == 2: - - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. - - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. - - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. - - If `order` == 3: - - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. - - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. - - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. - - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. - - 'multistep': - Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. - We initialize the first `order` values by lower order multistep solvers. - Given a fixed NFE == `steps`, the sampling procedure is: - Denote K = steps. - - If `order` == 1: - - We use K steps of DPM-Solver-1 (i.e. DDIM). - - If `order` == 2: - - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. - - If `order` == 3: - - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. - - 'singlestep_fixed': - Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). 
- We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. - - 'adaptive': - Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). - We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. - You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs - (NFE) and the sample quality. - - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. - - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. - - ===================================================== - - Some advice on choosing the algorithm: - - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: - Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`. - e.g. - >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) - >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, - skip_type='time_uniform', method='singlestep') - For **guided sampling with large guidance scale** by DPMs: - Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. - e.g. - >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) - >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, - skip_type='time_uniform', method='multistep') - - We support three types of `skip_type`: - - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images** - - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**. - - 'time_quadratic': quadratic time for the time steps. - - ===================================================== - Args: - x: A pytorch tensor. The initial value at time `t_start` - e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. - steps: An `int`. The total number of function evaluations (NFE). - t_start: A `float`. The starting time of the sampling. - If `t_start` is None, we use self.noise_schedule.T (default is 1.0). - t_end: A `float`. The ending time of the sampling. - If `t_end` is None, we use 1. / self.noise_schedule.total_N. - e.g. if total_N == 1000, we have `t_end` == 1e-3. - For discrete-time DPMs: - - We recommend `t_end` == 1. / self.noise_schedule.total_N. - For continuous-time DPMs: - - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. - order: An `int`. The order of DPM-Solver. - skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. - method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. - denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. - Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). - - This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and - score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID - when sampling diffusion models via diffusion SDEs for low-resolution images - (such as CIFAR-10). However, we observed that it does not matter for - high-resolution images. As it needs an additional NFE, we do not recommend - it for high-resolution images. - lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
- Only valid for `method=multistep` and `steps < 15`. We empirically find that - this trick is a key to stabilizing the sampling by DPM-Solver with very few steps - (especially for steps <= 10). So we recommend to set it to be `True`. - solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. - atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. - rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. - Returns: - x_end: A pytorch tensor. The approximated solution at time `t_end`. - - """ - t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end - t_T = self.noise_schedule.T if t_start is None else t_start - device = x.device - if method == 'adaptive': - with torch.no_grad(): - x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type) - elif method == 'multistep': - assert steps >= order - timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) - assert timesteps.shape[0] - 1 == steps - with torch.no_grad(): - vec_t = timesteps[0].expand((x.shape[0])) - model_prev_list = [self.model_fn(x, vec_t)] - t_prev_list = [vec_t] - # Init the first `order` values by lower order multistep DPM-Solver. - for init_order in range(1, order): - vec_t = timesteps[init_order].expand(x.shape[0]) - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type) - model_prev_list.append(self.model_fn(x, vec_t)) - t_prev_list.append(vec_t) - # Compute the remaining values by `order`-th order multistep DPM-Solver. - for step in range(order, steps + 1): - vec_t = timesteps[step].expand(x.shape[0]) - if lower_order_final and steps < 15: - step_order = min(order, steps + 1 - step) - else: - step_order = order - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, solver_type=solver_type) - for i in range(order - 1): - t_prev_list[i] = t_prev_list[i + 1] - model_prev_list[i] = model_prev_list[i + 1] - t_prev_list[-1] = vec_t - # We do not need to evaluate the final model value. 
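# (`model_prev_list` and `t_prev_list` form a sliding window over the last
# `order` model evaluations: the shift above drops the oldest entry, and the
# newest slot is refilled below on every step except the last.)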
- if step < steps: - model_prev_list[-1] = self.model_fn(x, vec_t) - elif method in ['singlestep', 'singlestep_fixed']: - if method == 'singlestep': - timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device) - elif method == 'singlestep_fixed': - K = steps // order - orders = [order,] * K - timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) - for i, order in enumerate(orders): - t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] - timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), N=order, device=device) - lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) - vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) - h = lambda_inner[-1] - lambda_inner[0] - r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h - r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h - x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) - if denoise_to_zero: - x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) - return x - - - -############################################################# -# other utility functions -############################################################# - -def interpolate_fn(x, xp, yp): - """ - A piecewise linear function y = f(x), using xp and yp as keypoints. - We implement f(x) in a differentiable way (i.e. applicable for autograd). - The function f(x) is well-defined on the whole x-axis. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.) - - Args: - x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). - xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. - yp: PyTorch tensor with shape [C, K]. - Returns: - The function values f(x), with shape [N, C]. - """ - N, K = x.shape[0], xp.shape[1] - all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) - sorted_all_x, x_indices = torch.sort(all_x, dim=2) - x_idx = torch.argmin(x_indices, dim=2) - cand_start_idx = x_idx - 1 - start_idx = torch.where( - torch.eq(x_idx, 0), - torch.tensor(1, device=x.device), - torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, - ), - ) - end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) - start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) - end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) - start_idx2 = torch.where( - torch.eq(x_idx, 0), - torch.tensor(0, device=x.device), - torch.where( - torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, - ), - ) - y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) - start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) - end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) - cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) - return cand - - -def expand_dims(v, dims): - """ - Expand the tensor `v` to `dims` dimensions. - - Args: - `v`: a PyTorch tensor with shape [N]. - `dims`: an `int`, the target total number of dimensions. - Returns: - a PyTorch tensor with shape [N, 1, 1, ..., 1], where the total number of dimensions is `dims`.
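# Example usage of `expand_dims` (illustrative):
#   >>> v = torch.randn(4)
#   >>> expand_dims(v, 4).shape
#   torch.Size([4, 1, 1, 1])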
- """ - return v[(...,) + (None,)*(dims - 1)] \ No newline at end of file diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py deleted file mode 100644 index 2c42d6f9..00000000 --- a/ldm/models/diffusion/dpm_solver/sampler.py +++ /dev/null @@ -1,82 +0,0 @@ -"""SAMPLING ONLY.""" - -import torch - -from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver - - -class DPMSolverSampler(object): - def __init__(self, model, **kwargs): - super().__init__() - self.model = model - to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) - self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) - setattr(self, name, attr) - - @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs - ): - if conditioning is not None: - if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - - # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') - - device = self.model.betas.device - if x_T is None: - img = torch.randn(size, device=device) - else: - img = x_T - - ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) - - model_fn = model_wrapper( - lambda x, t, c: self.model.apply_model(x, t, c), - ns, - model_type="noise", - guidance_type="classifier-free", - condition=conditioning, - unconditional_condition=unconditional_conditioning, - guidance_scale=unconditional_guidance_scale, - ) - - dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) - x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) - - return x.to(device), None diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py deleted file mode 100644 index 78eeb100..00000000 --- a/ldm/models/diffusion/plms.py +++ /dev/null @@ -1,236 +0,0 @@ -"""SAMPLING ONLY.""" - -import torch -import numpy as np -from tqdm import tqdm -from functools import partial - -from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like - - -class PLMSSampler(object): - def __init__(self, model, schedule="linear", **kwargs): - super().__init__() - self.model = model - self.ddpm_num_timesteps = model.num_timesteps - self.schedule = schedule - - def register_buffer(self, name, attr): - if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): - attr = attr.to(torch.device("cuda")) - setattr(self, name, attr) - - def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): - if ddim_eta != 0: - raise ValueError('ddim_eta must be 0 
for PLMS') - self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, - num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) - alphas_cumprod = self.model.alphas_cumprod - assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' - to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) - - self.register_buffer('betas', to_torch(self.model.betas)) - self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) - self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) - self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) - self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) - self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) - - # ddim sampling parameters - ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), - ddim_timesteps=self.ddim_timesteps, - eta=ddim_eta,verbose=verbose) - self.register_buffer('ddim_sigmas', ddim_sigmas) - self.register_buffer('ddim_alphas', ddim_alphas) - self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) - self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) - sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( - (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( - 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) - self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) - - @torch.no_grad() - def sample(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
- **kwargs - ): - if conditioning is not None: - if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - print(f'Data shape for PLMS sampling is {size}') - - samples, intermediates = self.plms_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - ) - return samples, intermediates - - @torch.no_grad() - def plms_sampling(self, cond, shape, - x_T=None, ddim_use_original_steps=False, - callback=None, timesteps=None, quantize_denoised=False, - mask=None, x0=None, img_callback=None, log_every_t=100, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None,): - device = self.model.betas.device - b = shape[0] - if x_T is None: - img = torch.randn(shape, device=device) - else: - img = x_T - - if timesteps is None: - timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps - elif timesteps is not None and not ddim_use_original_steps: - subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 - timesteps = self.ddim_timesteps[:subset_end] - - intermediates = {'x_inter': [img], 'pred_x0': [img]} - time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) - total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] - print(f"Running PLMS Sampling with {total_steps} timesteps") - - iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) - old_eps = [] - - for i, step in enumerate(iterator): - index = total_steps - i - 1 - ts = torch.full((b,), step, device=device, dtype=torch.long) - ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) - - if mask is not None: - assert x0 is not None - img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? - img = img_orig * mask + (1. 
- mask) * img - - outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, - quantize_denoised=quantize_denoised, temperature=temperature, - noise_dropout=noise_dropout, score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - old_eps=old_eps, t_next=ts_next) - img, pred_x0, e_t = outs - old_eps.append(e_t) - if len(old_eps) >= 4: - old_eps.pop(0) - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - - return img, intermediates - - @torch.no_grad() - def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): - b, *_, device = *x.shape, x.device - - def get_model_output(x, t): - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: - e_t = self.model.apply_model(x, t, c) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - c_in = torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - - if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) - - return e_t - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - - def get_x_prev_and_pred_x0(e_t, index): - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) - - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - # direction pointing to x_t - dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - e_t = get_model_output(x, t) - if len(old_eps) == 0: - # Pseudo Improved Euler (2nd order) - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) - e_t_next = get_model_output(x_prev, t_next) - e_t_prime = (e_t + e_t_next) / 2 - elif len(old_eps) == 1: - # 2nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (3 * e_t - old_eps[-1]) / 2 - elif len(old_eps) == 2: - # 3rd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 - elif len(old_eps) >= 3: - # 4th order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 - - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) - - return x_prev, pred_x0, e_t diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py deleted file mode 100644 index f4eff39c..00000000 --- a/ldm/modules/attention.py +++ /dev/null @@ -1,261 +0,0 @@ -from inspect import isfunction -import math -import torch -import torch.nn.functional as F -from torch import nn, einsum -from einops import rearrange, repeat - -from ldm.modules.diffusionmodules.util import checkpoint - - -def exists(val): - return val is not None - - -def uniq(arr): - return {el: True for el in arr}.keys() - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def max_neg_value(t): - return -torch.finfo(t.dtype).max - - -def init_(tensor): - dim = tensor.shape[-1] - std = 1 / math.sqrt(dim) - tensor.uniform_(-std, std) - return tensor - - -# feedforward -class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): - super().__init__() - inner_dim = int(dim * mult) - dim_out = default(dim_out, dim) - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) - - self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) - ) - - def forward(self, x): - return self.net(x) - - -def zero_module(module): - """ - Zero out the parameters of a module and return it.
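# Example usage of `zero_module` (illustrative): zeroing a block's output
# projection makes the block start as the identity in residual paths.
#   >>> conv = zero_module(nn.Conv2d(64, 64, 1))
#   >>> bool((conv.weight == 0).all()) and bool((conv.bias == 0).all())
#   True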
- """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def Normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -class LinearAttention(nn.Module): - def __init__(self, dim, heads=4, dim_head=32): - super().__init__() - self.heads = heads - hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) - self.to_out = nn.Conv2d(hidden_dim, dim, 1) - - def forward(self, x): - b, c, h, w = x.shape - qkv = self.to_qkv(x) - q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) - k = k.softmax(dim=-1) - context = torch.einsum('bhdn,bhen->bhde', k, v) - out = torch.einsum('bhde,bhdn->bhen', context, q) - out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) - return self.to_out(out) - - -class SpatialSelfAttention(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b,c,h,w = q.shape - q = rearrange(q, 'b c h w -> b (h w) c') - k = rearrange(k, 'b c h w -> b c (h w)') - w_ = torch.einsum('bij,bjk->bik', q, k) - - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = rearrange(v, 'b c h w -> b c (h w)') - w_ = rearrange(w_, 'b i j -> b j i') - h_ = torch.einsum('bij,bjk->bik', v, w_) - h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) - h_ = self.proj_out(h_) - - return x+h_ - - -class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): - super().__init__() - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) - - self.scale = dim_head ** -0.5 - self.heads = heads - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), - nn.Dropout(dropout) - ) - - def forward(self, x, context=None, mask=None): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - - if exists(mask): - mask = rearrange(mask, 'b ... 
-> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) - - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) - - -class BasicTransformerBlock(nn.Module): - def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): - super().__init__() - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, - heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.norm3 = nn.LayerNorm(dim) - self.checkpoint = checkpoint - - def forward(self, x, context=None): - return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) - - def _forward(self, x, context=None): - x = self.attn1(self.norm1(x)) + x - x = self.attn2(self.norm2(x), context=context) + x - x = self.ff(self.norm3(x)) + x - return x - - -class SpatialTransformer(nn.Module): - """ - Transformer block for image-like data. - First, project the input (aka embedding) - and reshape to b, t, d. - Then apply standard transformer action. - Finally, reshape to image - """ - def __init__(self, in_channels, n_heads, d_head, - depth=1, dropout=0., context_dim=None): - super().__init__() - self.in_channels = in_channels - inner_dim = n_heads * d_head - self.norm = Normalize(in_channels) - - self.proj_in = nn.Conv2d(in_channels, - inner_dim, - kernel_size=1, - stride=1, - padding=0) - - self.transformer_blocks = nn.ModuleList( - [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) - for d in range(depth)] - ) - - self.proj_out = zero_module(nn.Conv2d(inner_dim, - in_channels, - kernel_size=1, - stride=1, - padding=0)) - - def forward(self, x, context=None): - # note: if no context is given, cross-attention defaults to self-attention - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - x = self.proj_in(x) - x = rearrange(x, 'b c h w -> b (h w) c') - for block in self.transformer_blocks: - x = block(x, context=context) - x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) - x = self.proj_out(x) - return x + x_in \ No newline at end of file diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py deleted file mode 100644 index 533e589a..00000000 --- a/ldm/modules/diffusionmodules/model.py +++ /dev/null @@ -1,835 +0,0 @@ -# pytorch_diffusion + derived encoder decoder -import math -import torch -import torch.nn as nn -import numpy as np -from einops import rearrange - -from ldm.util import instantiate_from_config -from ldm.modules.attention import LinearAttention - - -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: - From Fairseq. - Build sinusoidal embeddings. - This matches the implementation in tensor2tensor, but differs slightly - from the description in Section 3.5 of "Attention Is All You Need". 
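# Example usage of `get_timestep_embedding` (illustrative): a 1-D batch of
# timesteps maps to a [batch, embedding_dim] sinusoidal embedding.
#   >>> emb = get_timestep_embedding(torch.arange(8), 64)
#   >>> emb.shape
#   torch.Size([8, 64])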
- """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0,1,0,0)) - return emb - - -def nonlinearity(x): - # swish - return x*torch.sigmoid(x) - - -def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) - - -class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=3, - stride=2, - padding=0) - - def forward(self, x): - if self.with_conv: - pad = (0,1,0,1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - return x - - -class ResnetBlock(nn.Module): - def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, - dropout, temb_channels=512): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, - out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d(out_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - else: - self.nin_shortcut = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=1, - stride=1, - padding=0) - - def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x+h - - -class LinAttnBlock(LinearAttention): - """to match AttnBlock usage""" - def __init__(self, in_channels): - super().__init__(dim=in_channels, heads=1, dim_head=in_channels) - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = 
torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b,c,h,w = q.shape - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b,c,h*w) - w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b,c,h,w) - - h_ = self.proj_out(h_) - - return x+h_ - - -def make_attn(in_channels, attn_type="vanilla"): - assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' - print(f"making attention of type '{attn_type}' with {in_channels} in_channels") - if attn_type == "vanilla": - return AttnBlock(in_channels) - elif attn_type == "none": - return nn.Identity(in_channels) - else: - return LinAttnBlock(in_channels) - - -class Model(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): - super().__init__() - if use_linear_attn: attn_type = "linear" - self.ch = ch - self.temb_ch = self.ch*4 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - self.use_timestep = use_timestep - if self.use_timestep: - # timestep embedding - self.temb = nn.Module() - self.temb.dense = nn.ModuleList([ - torch.nn.Linear(self.ch, - self.temb_ch), - torch.nn.Linear(self.temb_ch, - self.temb_ch), - ]) - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) - - curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions-1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - skip_in = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): - 
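# Each upsampling block consumes one skip activation from `hs` (concatenated
# along the channel axis in `forward`); the extra (num_res_blocks+1)-th block
# pairs with the pre-downsample activation, whose channel count follows
# `in_ch_mult` rather than `ch_mult`, hence the adjustment below.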
if i_block == self.num_res_blocks: - skip_in = ch*in_ch_mult[i_level] - block.append(ResnetBlock(in_channels=block_in+skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x, t=None, context=None): - #assert x.shape[2] == x.shape[3] == self.resolution - if context is not None: - # assume aligned context, cat along channel axis - x = torch.cat((x, context), dim=1) - if self.use_timestep: - # timestep embedding - assert t is not None - temb = get_timestep_embedding(t, self.ch) - temb = self.temb.dense[0](temb) - temb = nonlinearity(temb) - temb = self.temb.dense[1](temb) - else: - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions-1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - def get_last_layer(self): - return self.conv_out.weight - - -class Encoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", - **ignore_kwargs): - super().__init__() - if use_linear_attn: attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - # downsampling - self.conv_in = torch.nn.Conv2d(in_channels, - self.ch, - kernel_size=3, - stride=1, - padding=1) - - curr_res = resolution - in_ch_mult = (1,)+tuple(ch_mult) - self.in_ch_mult = in_ch_mult - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch*in_ch_mult[i_level] - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions-1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - 
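# Bottleneck at the lowest resolution: ResnetBlock -> attention -> ResnetBlock,
# mirroring the middle stack used by the decoder.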
self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - 2*z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - # timestep embedding - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions-1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Decoder(nn.Module): - def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, - attn_type="vanilla", **ignorekwargs): - super().__init__() - if use_linear_attn: attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end - self.tanh_out = tanh_out - - # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,)+tuple(ch_mult) - block_in = ch*ch_mult[self.num_resolutions-1] - curr_res = resolution // 2**(self.num_resolutions-1) - self.z_shape = (1,z_channels,curr_res,curr_res) - print("Working with z of shape {} = {} dimensions.".format( - self.z_shape, np.prod(self.z_shape))) - - # z to block_in - self.conv_in = torch.nn.Conv2d(z_channels, - block_in, - kernel_size=3, - stride=1, - padding=1) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock(in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch*ch_mult[i_level] - for i_block in range(self.num_res_blocks+1): - block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_ch, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, z): - #assert z.shape[1:] == self.z_shape[1:] - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - # z to block_in - h = 
self.conv_in(z) - - # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block](h, temb) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - if self.give_pre_end: - return h - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - if self.tanh_out: - h = torch.tanh(h) - return h - - -class SimpleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, *args, **kwargs): - super().__init__() - self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), - ResnetBlock(in_channels=in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=2 * in_channels, - out_channels=4 * in_channels, - temb_channels=0, dropout=0.0), - ResnetBlock(in_channels=4 * in_channels, - out_channels=2 * in_channels, - temb_channels=0, dropout=0.0), - nn.Conv2d(2*in_channels, in_channels, 1), - Upsample(in_channels, with_conv=True)]) - # end - self.norm_out = Normalize(in_channels) - self.conv_out = torch.nn.Conv2d(in_channels, - out_channels, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - for i, layer in enumerate(self.model): - if i in [1,2,3]: - x = layer(x, None) - else: - x = layer(x) - - h = self.norm_out(x) - h = nonlinearity(h) - x = self.conv_out(h) - return x - - -class UpsampleDecoder(nn.Module): - def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, - ch_mult=(2,2), dropout=0.0): - super().__init__() - # upsampling - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - block_in = in_channels - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.res_blocks = nn.ModuleList() - self.upsample_blocks = nn.ModuleList() - for i_level in range(self.num_resolutions): - res_block = [] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - res_block.append(ResnetBlock(in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout)) - block_in = block_out - self.res_blocks.append(nn.ModuleList(res_block)) - if i_level != self.num_resolutions - 1: - self.upsample_blocks.append(Upsample(block_in, True)) - curr_res = curr_res * 2 - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d(block_in, - out_channels, - kernel_size=3, - stride=1, - padding=1) - - def forward(self, x): - # upsampling - h = x - for k, i_level in enumerate(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.res_blocks[i_level][i_block](h, None) - if i_level != self.num_resolutions - 1: - h = self.upsample_blocks[k](h) - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class LatentRescaler(nn.Module): - def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): - super().__init__() - # residual block, interpolate, residual block - self.factor = factor - self.conv_in = nn.Conv2d(in_channels, - mid_channels, - kernel_size=3, - stride=1, - padding=1) - self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, - out_channels=mid_channels, - temb_channels=0, - dropout=0.0) for _ in range(depth)]) - self.attn = AttnBlock(mid_channels) - self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, 
- out_channels=mid_channels, - temb_channels=0, - dropout=0.0) for _ in range(depth)]) - - self.conv_out = nn.Conv2d(mid_channels, - out_channels, - kernel_size=1, - ) - - def forward(self, x): - x = self.conv_in(x) - for block in self.res_block1: - x = block(x, None) - x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) - x = self.attn(x) - for block in self.res_block2: - x = block(x, None) - x = self.conv_out(x) - return x - - -class MergedRescaleEncoder(nn.Module): - def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, - attn_resolutions, dropout=0.0, resamp_with_conv=True, - ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): - super().__init__() - intermediate_chn = ch * ch_mult[-1] - self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, - z_channels=intermediate_chn, double_z=False, resolution=resolution, - attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, - out_ch=None) - self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, - mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) - - def forward(self, x): - x = self.encoder(x) - x = self.rescaler(x) - return x - - -class MergedRescaleDecoder(nn.Module): - def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), - dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): - super().__init__() - tmp_chn = z_channels*ch_mult[-1] - self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, - resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, - ch_mult=ch_mult, resolution=resolution, ch=ch) - self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, - out_channels=tmp_chn, depth=rescale_module_depth) - - def forward(self, x): - x = self.rescaler(x) - x = self.decoder(x) - return x - - -class Upsampler(nn.Module): - def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): - super().__init__() - assert out_size >= in_size - num_blocks = int(np.log2(out_size//in_size))+1 - factor_up = 1.+ (out_size % in_size) - print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") - self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, - out_channels=in_channels) - self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, - attn_resolutions=[], in_channels=None, ch=in_channels, - ch_mult=[ch_mult for _ in range(num_blocks)]) - - def forward(self, x): - x = self.rescaler(x) - x = self.decoder(x) - return x - - -class Resize(nn.Module): - def __init__(self, in_channels=None, learned=False, mode="bilinear"): - super().__init__() - self.with_conv = learned - self.mode = mode - if self.with_conv: - print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") - raise NotImplementedError() - assert in_channels is not None - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=4, - stride=2, - padding=1) - - def forward(self, x, scale_factor=1.0): - if scale_factor==1.0: - return x - else: - x = torch.nn.functional.interpolate(x, 
mode=self.mode, align_corners=False, scale_factor=scale_factor) - return x - -class FirstStagePostProcessor(nn.Module): - - def __init__(self, ch_mult:list, in_channels, - pretrained_model:nn.Module=None, - reshape=False, - n_channels=None, - dropout=0., - pretrained_config=None): - super().__init__() - if pretrained_config is None: - assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' - self.pretrained_model = pretrained_model - else: - assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' - self.instantiate_pretrained(pretrained_config) - - self.do_reshape = reshape - - if n_channels is None: - n_channels = self.pretrained_model.encoder.ch - - self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) - self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, - stride=1,padding=1) - - blocks = [] - downs = [] - ch_in = n_channels - for m in ch_mult: - blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) - ch_in = m * n_channels - downs.append(Downsample(ch_in, with_conv=False)) - - self.model = nn.ModuleList(blocks) - self.downsampler = nn.ModuleList(downs) - - - def instantiate_pretrained(self, config): - model = instantiate_from_config(config) - self.pretrained_model = model.eval() - # self.pretrained_model.train = False - for param in self.pretrained_model.parameters(): - param.requires_grad = False - - - @torch.no_grad() - def encode_with_pretrained(self,x): - c = self.pretrained_model.encode(x) - if isinstance(c, DiagonalGaussianDistribution): - c = c.mode() - return c - - def forward(self,x): - z_fs = self.encode_with_pretrained(x) - z = self.proj_norm(z_fs) - z = self.proj(z) - z = nonlinearity(z) - - for submodel, downmodel in zip(self.model,self.downsampler): - z = submodel(z,temb=None) - z = downmodel(z) - - if self.do_reshape: - z = rearrange(z,'b c h w -> b (h w) c') - return z - diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py deleted file mode 100644 index fcf95d1e..00000000 --- a/ldm/modules/diffusionmodules/openaimodel.py +++ /dev/null @@ -1,961 +0,0 @@ -from abc import abstractmethod -from functools import partial -import math -from typing import Iterable - -import numpy as np -import torch as th -import torch.nn as nn -import torch.nn.functional as F - -from ldm.modules.diffusionmodules.util import ( - checkpoint, - conv_nd, - linear, - avg_pool_nd, - zero_module, - normalization, - timestep_embedding, -) -from ldm.modules.attention import SpatialTransformer - - -# dummy replace -def convert_module_to_f16(x): - pass - -def convert_module_to_f32(x): - pass - - -## go -class AttentionPool2d(nn.Module): - """ - Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py - """ - - def __init__( - self, - spacial_dim: int, - embed_dim: int, - num_heads_channels: int, - output_dim: int = None, - ): - super().__init__() - self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) - self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) - self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) - self.num_heads = embed_dim // num_heads_channels - self.attention = QKVAttention(self.num_heads) - - def forward(self, x): - b, c, *_spatial = x.shape - x = x.reshape(b, c, -1) # NC(HW) - x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) - x = x + self.positional_embedding[None, :, :].to(x.dtype) # 
NC(HW+1) - x = self.qkv_proj(x) - x = self.attention(x) - x = self.c_proj(x) - return x[:, :, 0] - - -class TimestepBlock(nn.Module): - """ - Any module where forward() takes timestep embeddings as a second argument. - """ - - @abstractmethod - def forward(self, x, emb): - """ - Apply the module to `x` given `emb` timestep embeddings. - """ - - -class TimestepEmbedSequential(nn.Sequential, TimestepBlock): - """ - A sequential module that passes timestep embeddings to the children that - support it as an extra input. - """ - - def forward(self, x, emb, context=None): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb) - elif isinstance(layer, SpatialTransformer): - x = layer(x, context) - else: - x = layer(x) - return x - - -class Upsample(nn.Module): - """ - An upsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - if use_conv: - self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) - - def forward(self, x): - assert x.shape[1] == self.channels - if self.dims == 3: - x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - ) - else: - x = F.interpolate(x, scale_factor=2, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x - -class TransposedUpsample(nn.Module): - 'Learned 2x upsampling without padding' - def __init__(self, channels, out_channels=None, ks=5): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - - self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) - - def forward(self,x): - return self.up(x) - - -class Downsample(nn.Module): - """ - A downsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - downsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - stride = 2 if dims != 3 else (1, 2, 2) - if use_conv: - self.op = conv_nd( - dims, self.channels, self.out_channels, 3, stride=stride, padding=padding - ) - else: - assert self.channels == self.out_channels - self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) - - def forward(self, x): - assert x.shape[1] == self.channels - return self.op(x) - - -class ResBlock(TimestepBlock): - """ - A residual block that can optionally change the number of channels. - :param channels: the number of input channels. - :param emb_channels: the number of timestep embedding channels. - :param dropout: the rate of dropout. - :param out_channels: if specified, the number of out channels. - :param use_conv: if True and out_channels is specified, use a spatial - convolution instead of a smaller 1x1 convolution to change the - channels in the skip connection. - :param dims: determines if the signal is 1D, 2D, or 3D. 
- :param use_checkpoint: if True, use gradient checkpointing on this module. - :param up: if True, use this block for upsampling. - :param down: if True, use this block for downsampling. - """ - - def __init__( - self, - channels, - emb_channels, - dropout, - out_channels=None, - use_conv=False, - use_scale_shift_norm=False, - dims=2, - use_checkpoint=False, - up=False, - down=False, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_checkpoint = use_checkpoint - self.use_scale_shift_norm = use_scale_shift_norm - - self.in_layers = nn.Sequential( - normalization(channels), - nn.SiLU(), - conv_nd(dims, channels, self.out_channels, 3, padding=1), - ) - - self.updown = up or down - - if up: - self.h_upd = Upsample(channels, False, dims) - self.x_upd = Upsample(channels, False, dims) - elif down: - self.h_upd = Downsample(channels, False, dims) - self.x_upd = Downsample(channels, False, dims) - else: - self.h_upd = self.x_upd = nn.Identity() - - self.emb_layers = nn.Sequential( - nn.SiLU(), - linear( - emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, - ), - ) - self.out_layers = nn.Sequential( - normalization(self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) - ), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - elif use_conv: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1 - ) - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) - - def forward(self, x, emb): - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - :param x: an [N x C x ...] Tensor of features. - :param emb: an [N x emb_channels] Tensor of timestep embeddings. - :return: an [N x C x ...] Tensor of outputs. - """ - return checkpoint( - self._forward, (x, emb), self.parameters(), self.use_checkpoint - ) - - - def _forward(self, x, emb): - if self.updown: - in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] - h = in_rest(x) - h = self.h_upd(h) - x = self.x_upd(x) - h = in_conv(h) - else: - h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - if self.use_scale_shift_norm: - out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = th.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift - h = out_rest(h) - else: - h = h + emb_out - h = self.out_layers(h) - return self.skip_connection(x) + h - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. - Originally ported from here, but adapted to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
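    A standalone sketch of what this block computes, in plain torch with
    illustrative sizes (the real module also wraps _forward in gradient
    checkpointing, and proj_out is zero-initialized via zero_module so the
    block starts out as an identity):

        import torch
        import torch.nn as nn

        b, c, hgt, wid, heads = 2, 64, 8, 8, 4
        x = torch.randn(b, c, hgt, wid)

        norm = nn.GroupNorm(32, c)
        qkv_proj = nn.Conv1d(c, 3 * c, 1)
        proj_out = nn.Conv1d(c, c, 1)
        nn.init.zeros_(proj_out.weight)
        nn.init.zeros_(proj_out.bias)                  # zero_module(...)

        x_flat = x.reshape(b, c, -1)                   # [N, C, H*W]
        qkv = qkv_proj(norm(x_flat))                   # [N, 3C, H*W]
        ch = c // heads
        q, k, v = qkv.reshape(b * heads, 3 * ch, -1).split(ch, dim=1)
        scale = ch ** -0.25
        w = torch.softmax(torch.einsum("bct,bcs->bts", q * scale, k * scale), dim=-1)
        h = torch.einsum("bts,bcs->bct", w, v).reshape(b, c, -1)
        out = (x_flat + proj_out(h)).reshape(b, c, hgt, wid)
        assert torch.equal(out, x)                     # identity at init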
- """ - - def __init__( - self, - channels, - num_heads=1, - num_head_channels=-1, - use_checkpoint=False, - use_new_attention_order=False, - ): - super().__init__() - self.channels = channels - if num_head_channels == -1: - self.num_heads = num_heads - else: - assert ( - channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" - self.num_heads = channels // num_head_channels - self.use_checkpoint = use_checkpoint - self.norm = normalization(channels) - self.qkv = conv_nd(1, channels, channels * 3, 1) - if use_new_attention_order: - # split qkv before split heads - self.attention = QKVAttention(self.num_heads) - else: - # split heads before split qkv - self.attention = QKVAttentionLegacy(self.num_heads) - - self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) - - def forward(self, x): - return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! - #return pt_checkpoint(self._forward, x) # pytorch - - def _forward(self, x): - b, c, *spatial = x.shape - x = x.reshape(b, c, -1) - qkv = self.qkv(self.norm(x)) - h = self.attention(qkv) - h = self.proj_out(h) - return (x + h).reshape(b, c, *spatial) - - -def count_flops_attn(model, _x, y): - """ - A counter for the `thop` package to count the operations in an - attention operation. - Meant to be used like: - macs, params = thop.profile( - model, - inputs=(inputs, timestamps), - custom_ops={QKVAttention: QKVAttention.count_flops}, - ) - """ - b, c, *spatial = y[0].shape - num_spatial = int(np.prod(spatial)) - # We perform two matmuls with the same number of ops. - # The first computes the weight matrix, the second computes - # the combination of the value vectors. - matmul_ops = 2 * b * (num_spatial ** 2) * c - model.total_ops += th.DoubleTensor([matmul_ops]) - - -class QKVAttentionLegacy(nn.Module): - """ - A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping - """ - - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv): - """ - Apply QKV attention. - :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. - """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v) - return a.reshape(bs, -1, length) - - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) - - -class QKVAttention(nn.Module): - """ - A module which performs QKV attention and splits in a different order. - """ - - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv): - """ - Apply QKV attention. - :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. 
- """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.chunk(3, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", - (q * scale).view(bs * self.n_heads, ch, length), - (k * scale).view(bs * self.n_heads, ch, length), - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) - return a.reshape(bs, -1, length) - - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) - - -class UNetModel(nn.Module): - """ - The full UNet model with attention and timestep embedding. - :param in_channels: channels in the input Tensor. - :param model_channels: base channel count for the model. - :param out_channels: channels in the output Tensor. - :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. - :param dropout: the dropout probability. - :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param num_classes: if specified (as an int), then this model will be - class-conditional with `num_classes` classes. - :param use_checkpoint: use gradient checkpointing to reduce memory usage. - :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. - :param resblock_updown: use residual blocks for up/downsampling. - :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. - """ - - def __init__( - self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - num_classes=None, - use_checkpoint=False, - use_fp16=False, - num_heads=-1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - use_spatial_transformer=False, # custom transformer support - transformer_depth=1, # custom transformer support - context_dim=None, # custom transformer support - n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model - legacy=True, - ): - super().__init__() - if use_spatial_transformer: - assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' - - if context_dim is not None: - assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 
- from omegaconf.listconfig import ListConfig - if type(context_dim) == ListConfig: - context_dim = list(context_dim) - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - if num_heads == -1: - assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' - - if num_head_channels == -1: - assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' - - self.image_size = image_size - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.num_classes = num_classes - self.use_checkpoint = use_checkpoint - self.dtype = th.float16 if use_fp16 else th.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - self.predict_codebook_ids = n_embed is not None - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - if self.num_classes is not None: - self.label_emb = nn.Embedding(num_classes, time_embed_dim) - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=mult * model_channels, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - #num_heads = 1 - dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - #num_heads = 1 - dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - 
num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, mult in list(enumerate(channel_mult))[::-1]: - for i in range(num_res_blocks + 1): - ich = input_block_chans.pop() - layers = [ - ResBlock( - ch + ich, - time_embed_dim, - dropout, - out_channels=model_channels * mult, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = model_channels * mult - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - if legacy: - #num_heads = 1 - dim_head = ch // num_heads if use_spatial_transformer else num_head_channels - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads_upsample, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) if not use_spatial_transformer else SpatialTransformer( - ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim - ) - ) - if level and i == num_res_blocks: - out_ch = ch - layers.append( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - up=True, - ) - if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) - ) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), - ) - if self.predict_codebook_ids: - self.id_predictor = nn.Sequential( - normalization(ch), - conv_nd(dims, model_channels, n_embed, 1), - #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits - ) - - def convert_to_fp16(self): - """ - Convert the torso of the model to float16. - """ - self.input_blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - self.output_blocks.apply(convert_module_to_f16) - - def convert_to_fp32(self): - """ - Convert the torso of the model to float32. - """ - self.input_blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - self.output_blocks.apply(convert_module_to_f32) - - def forward(self, x, timesteps=None, context=None, y=None,**kwargs): - """ - Apply the model to an input batch. - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :param context: conditioning plugged in via crossattn - :param y: an [N] Tensor of labels, if class-conditional. - :return: an [N x C x ...] Tensor of outputs. 
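    A hedged construction-and-call sketch; the hyperparameters are
    illustrative toy values, not a shipped configuration, and the import
    assumes a checkout that still carries this module:

        import torch as th
        from ldm.modules.diffusionmodules.openaimodel import UNetModel

        unet = UNetModel(
            image_size=32, in_channels=4, model_channels=64, out_channels=4,
            num_res_blocks=1, attention_resolutions=(2,), channel_mult=(1, 2),
            num_heads=4,
        )
        x = th.randn(2, 4, 32, 32)        # e.g. first-stage latents
        t = th.randint(0, 1000, (2,))     # one timestep per batch element
        out = unet(x, timesteps=t)        # same shape as x
        assert out.shape == x.shape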
- """ - assert (y is not None) == ( - self.num_classes is not None - ), "must specify y if and only if the model is class-conditional" - hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) - emb = self.time_embed(t_emb) - - if self.num_classes is not None: - assert y.shape == (x.shape[0],) - emb = emb + self.label_emb(y) - - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb, context) - hs.append(h) - h = self.middle_block(h, emb, context) - for module in self.output_blocks: - h = th.cat([h, hs.pop()], dim=1) - h = module(h, emb, context) - h = h.type(x.dtype) - if self.predict_codebook_ids: - return self.id_predictor(h) - else: - return self.out(h) - - -class EncoderUNetModel(nn.Module): - """ - The half UNet model with attention and timestep embedding. - For usage, see UNet. - """ - - def __init__( - self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - use_checkpoint=False, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - pool="adaptive", - *args, - **kwargs - ): - super().__init__() - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.use_checkpoint = use_checkpoint - self.dtype = th.float16 if use_fp16 else th.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=mult * model_channels, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - 
dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self._feature_size += ch - self.pool = pool - if pool == "adaptive": - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - nn.AdaptiveAvgPool2d((1, 1)), - zero_module(conv_nd(dims, ch, out_channels, 1)), - nn.Flatten(), - ) - elif pool == "attention": - assert num_head_channels != -1 - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - AttentionPool2d( - (image_size // ds), ch, num_head_channels, out_channels - ), - ) - elif pool == "spatial": - self.out = nn.Sequential( - nn.Linear(self._feature_size, 2048), - nn.ReLU(), - nn.Linear(2048, self.out_channels), - ) - elif pool == "spatial_v2": - self.out = nn.Sequential( - nn.Linear(self._feature_size, 2048), - normalization(2048), - nn.SiLU(), - nn.Linear(2048, self.out_channels), - ) - else: - raise NotImplementedError(f"Unexpected {pool} pooling") - - def convert_to_fp16(self): - """ - Convert the torso of the model to float16. - """ - self.input_blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - - def convert_to_fp32(self): - """ - Convert the torso of the model to float32. - """ - self.input_blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - - def forward(self, x, timesteps): - """ - Apply the model to an input batch. - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :return: an [N x K] Tensor of outputs. - """ - emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) - - results = [] - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb) - if self.pool.startswith("spatial"): - results.append(h.type(x.dtype).mean(dim=(2, 3))) - h = self.middle_block(h, emb) - if self.pool.startswith("spatial"): - results.append(h.type(x.dtype).mean(dim=(2, 3))) - h = th.cat(results, axis=-1) - return self.out(h) - else: - h = h.type(x.dtype) - return self.out(h) - diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py deleted file mode 100644 index a952e6c4..00000000 --- a/ldm/modules/diffusionmodules/util.py +++ /dev/null @@ -1,267 +0,0 @@ -# adopted from -# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py -# and -# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py -# and -# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py -# -# thanks! 
- - -import os -import math -import torch -import torch.nn as nn -import numpy as np -from einops import repeat - -from ldm.util import instantiate_from_config - - -def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - if schedule == "linear": - betas = ( - torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 - ) - - elif schedule == "cosine": - timesteps = ( - torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s - ) - alphas = timesteps / (1 + cosine_s) * np.pi / 2 - alphas = torch.cos(alphas).pow(2) - alphas = alphas / alphas[0] - betas = 1 - alphas[1:] / alphas[:-1] - betas = np.clip(betas, a_min=0, a_max=0.999) - - elif schedule == "sqrt_linear": - betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) - elif schedule == "sqrt": - betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 - else: - raise ValueError(f"schedule '{schedule}' unknown.") - return betas.numpy() - - -def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): - if ddim_discr_method == 'uniform': - c = num_ddpm_timesteps // num_ddim_timesteps - ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) - elif ddim_discr_method == 'quad': - ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) - else: - raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') - - # assert ddim_timesteps.shape[0] == num_ddim_timesteps - # add one to get the final alpha values right (the ones from first scale to data during sampling) - steps_out = ddim_timesteps + 1 - if verbose: - print(f'Selected timesteps for ddim sampler: {steps_out}') - return steps_out - - -def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): - # select alphas for computing the variance schedule - alphas = alphacums[ddim_timesteps] - alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) - - # according the the formula provided in https://arxiv.org/abs/2010.02502 - sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) - if verbose: - print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') - print(f'For the chosen value of eta, which is {eta}, ' - f'this results in the following sigma_t schedule for ddim sampler {sigmas}') - return sigmas, alphas, alphas_prev - - -def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, - which defines the cumulative product of (1-beta) over time from t = [0,1]. - :param num_diffusion_timesteps: the number of betas to produce. - :param alpha_bar: a lambda that takes an argument t from 0 to 1 and - produces the cumulative product of (1-beta) up to that - part of the diffusion process. - :param max_beta: the maximum beta to use; use values lower than 1 to - prevent singularities. 
- """ - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas) - - -def extract_into_tensor(a, t, x_shape): - b, *_ = t.shape - out = a.gather(-1, t) - return out.reshape(b, *((1,) * (len(x_shape) - 1))) - - -def checkpoint(func, inputs, params, flag): - """ - Evaluate a function without caching intermediate activations, allowing for - reduced memory at the expense of extra compute in the backward pass. - :param func: the function to evaluate. - :param inputs: the argument sequence to pass to `func`. - :param params: a sequence of parameters `func` depends on but does not - explicitly take as arguments. - :param flag: if False, disable gradient checkpointing. - """ - if flag: - args = tuple(inputs) + tuple(params) - return CheckpointFunction.apply(func, len(inputs), *args) - else: - return func(*inputs) - - -class CheckpointFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, run_function, length, *args): - ctx.run_function = run_function - ctx.input_tensors = list(args[:length]) - ctx.input_params = list(args[length:]) - - with torch.no_grad(): - output_tensors = ctx.run_function(*ctx.input_tensors) - return output_tensors - - @staticmethod - def backward(ctx, *output_grads): - ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] - with torch.enable_grad(): - # Fixes a bug where the first op in run_function modifies the - # Tensor storage in place, which is not allowed for detach()'d - # Tensors. - shallow_copies = [x.view_as(x) for x in ctx.input_tensors] - output_tensors = ctx.run_function(*shallow_copies) - input_grads = torch.autograd.grad( - output_tensors, - ctx.input_tensors + ctx.input_params, - output_grads, - allow_unused=True, - ) - del ctx.input_tensors - del ctx.input_params - del output_tensors - return (None, None) + input_grads - - -def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): - """ - Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - if not repeat_only: - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ).to(device=timesteps.device) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - else: - embedding = repeat(timesteps, 'b -> b d', d=dim) - return embedding - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def scale_module(module, scale): - """ - Scale the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().mul_(scale) - return module - - -def mean_flat(tensor): - """ - Take the mean over all non-batch dimensions. - """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - -def normalization(channels): - """ - Make a standard normalization layer. - :param channels: number of input channels. - :return: an nn.Module for normalization. 
- """ - return GroupNorm32(32, channels) - - -# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. -class SiLU(nn.Module): - def forward(self, x): - return x * torch.sigmoid(x) - - -class GroupNorm32(nn.GroupNorm): - def forward(self, x): - return super().forward(x.float()).type(x.dtype) - -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. - """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -def linear(*args, **kwargs): - """ - Create a linear module. - """ - return nn.Linear(*args, **kwargs) - - -def avg_pool_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D average pooling module. - """ - if dims == 1: - return nn.AvgPool1d(*args, **kwargs) - elif dims == 2: - return nn.AvgPool2d(*args, **kwargs) - elif dims == 3: - return nn.AvgPool3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -class HybridConditioner(nn.Module): - - def __init__(self, c_concat_config, c_crossattn_config): - super().__init__() - self.concat_conditioner = instantiate_from_config(c_concat_config) - self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) - - def forward(self, c_concat, c_crossattn): - c_concat = self.concat_conditioner(c_concat) - c_crossattn = self.crossattn_conditioner(c_crossattn) - return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} - - -def noise_like(shape, device, repeat=False): - repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) - noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/ldm/modules/distributions/__init__.py b/ldm/modules/distributions/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py deleted file mode 100644 index f2b8ef90..00000000 --- a/ldm/modules/distributions/distributions.py +++ /dev/null @@ -1,92 +0,0 @@ -import torch -import numpy as np - - -class AbstractDistribution: - def sample(self): - raise NotImplementedError() - - def mode(self): - raise NotImplementedError() - - -class DiracDistribution(AbstractDistribution): - def __init__(self, value): - self.value = value - - def sample(self): - return self.value - - def mode(self): - return self.value - - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters, deterministic=False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) - - def sample(self): - x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) - return x - - def kl(self, other=None): - if self.deterministic: - return torch.Tensor([0.]) - else: - if other is None: - return 0.5 * torch.sum(torch.pow(self.mean, 2) - + self.var - 1.0 - self.logvar, - dim=[1, 2, 3]) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - 1.0 - self.logvar + other.logvar, - dim=[1, 2, 3]) - - def nll(self, sample, dims=[1,2,3]): - if 
self.deterministic: - return torch.Tensor([0.]) - logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims) - - def mode(self): - return self.mean - - -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 - Compute the KL divergence between two gaussians. - Shapes are automatically broadcasted, so batches can be compared to - scalars, among other use cases. - """ - tensor = None - for obj in (mean1, logvar1, mean2, logvar2): - if isinstance(obj, torch.Tensor): - tensor = obj - break - assert tensor is not None, "at least one argument must be a Tensor" - - # Force variances to be Tensors. Broadcasting helps convert scalars to - # Tensors, but it does not work for torch.exp(). - logvar1, logvar2 = [ - x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) - for x in (logvar1, logvar2) - ] - - return 0.5 * ( - -1.0 - + logvar2 - - logvar1 - + torch.exp(logvar1 - logvar2) - + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) - ) diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py deleted file mode 100644 index c8c75af4..00000000 --- a/ldm/modules/ema.py +++ /dev/null @@ -1,76 +0,0 @@ -import torch -from torch import nn - - -class LitEma(nn.Module): - def __init__(self, model, decay=0.9999, use_num_upates=True): - super().__init__() - if decay < 0.0 or decay > 1.0: - raise ValueError('Decay must be between 0 and 1') - - self.m_name2s_name = {} - self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) - self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates - else torch.tensor(-1,dtype=torch.int)) - - for name, p in model.named_parameters(): - if p.requires_grad: - #remove as '.'-character is not allowed in buffers - s_name = name.replace('.','') - self.m_name2s_name.update({name:s_name}) - self.register_buffer(s_name,p.clone().detach().data) - - self.collected_params = [] - - def forward(self,model): - decay = self.decay - - if self.num_updates >= 0: - self.num_updates += 1 - decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) - - one_minus_decay = 1.0 - decay - - with torch.no_grad(): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - - for key in m_param: - if m_param[key].requires_grad: - sname = self.m_name2s_name[key] - shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) - shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) - else: - assert not key in self.m_name2s_name - - def copy_to(self, model): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - for key in m_param: - if m_param[key].requires_grad: - m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) - else: - assert not key in self.m_name2s_name - - def store(self, parameters): - """ - Save the current parameters for restoring later. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. - """ - self.collected_params = [param.clone() for param in parameters] - - def restore(self, parameters): - """ - Restore the parameters stored with the `store` method. - Useful to validate the model with EMA parameters without affecting the - original optimization process. Store the parameters before the - `copy_to` method. 
After validation (or model saving), use this to - restore the former parameters. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. - """ - for c_param, param in zip(self.collected_params, parameters): - param.data.copy_(c_param.data) diff --git a/ldm/modules/encoders/__init__.py b/ldm/modules/encoders/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py deleted file mode 100644 index ededbe43..00000000 --- a/ldm/modules/encoders/modules.py +++ /dev/null @@ -1,234 +0,0 @@ -import torch -import torch.nn as nn -from functools import partial -import clip -from einops import rearrange, repeat -from transformers import CLIPTokenizer, CLIPTextModel -import kornia - -from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test - - -class AbstractEncoder(nn.Module): - def __init__(self): - super().__init__() - - def encode(self, *args, **kwargs): - raise NotImplementedError - - - -class ClassEmbedder(nn.Module): - def __init__(self, embed_dim, n_classes=1000, key='class'): - super().__init__() - self.key = key - self.embedding = nn.Embedding(n_classes, embed_dim) - - def forward(self, batch, key=None): - if key is None: - key = self.key - # this is for use in crossattn - c = batch[key][:, None] - c = self.embedding(c) - return c - - -class TransformerEmbedder(AbstractEncoder): - """Some transformer encoder layers""" - def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): - super().__init__() - self.device = device - self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, - attn_layers=Encoder(dim=n_embed, depth=n_layer)) - - def forward(self, tokens): - tokens = tokens.to(self.device) # meh - z = self.transformer(tokens, return_embeddings=True) - return z - - def encode(self, x): - return self(x) - - -class BERTTokenizer(AbstractEncoder): - """ Uses a pretrained BERT tokenizer by huggingface. 
Vocab size: 30522 (?)""" - def __init__(self, device="cuda", vq_interface=True, max_length=77): - super().__init__() - from transformers import BertTokenizerFast # TODO: add to reuquirements - self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") - self.device = device - self.vq_interface = vq_interface - self.max_length = max_length - - def forward(self, text): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") - tokens = batch_encoding["input_ids"].to(self.device) - return tokens - - @torch.no_grad() - def encode(self, text): - tokens = self(text) - if not self.vq_interface: - return tokens - return None, None, [None, None, tokens] - - def decode(self, text): - return text - - -class BERTEmbedder(AbstractEncoder): - """Uses the BERT tokenizr model and add some transformer encoder layers""" - def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, - device="cuda",use_tokenizer=True, embedding_dropout=0.0): - super().__init__() - self.use_tknz_fn = use_tokenizer - if self.use_tknz_fn: - self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) - self.device = device - self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, - attn_layers=Encoder(dim=n_embed, depth=n_layer), - emb_dropout=embedding_dropout) - - def forward(self, text): - if self.use_tknz_fn: - tokens = self.tknz_fn(text)#.to(self.device) - else: - tokens = text - z = self.transformer(tokens, return_embeddings=True) - return z - - def encode(self, text): - # output of length 77 - return self(text) - - -class SpatialRescaler(nn.Module): - def __init__(self, - n_stages=1, - method='bilinear', - multiplier=0.5, - in_channels=3, - out_channels=None, - bias=False): - super().__init__() - self.n_stages = n_stages - assert self.n_stages >= 0 - assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] - self.multiplier = multiplier - self.interpolator = partial(torch.nn.functional.interpolate, mode=method) - self.remap_output = out_channels is not None - if self.remap_output: - print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') - self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) - - def forward(self,x): - for stage in range(self.n_stages): - x = self.interpolator(x, scale_factor=self.multiplier) - - - if self.remap_output: - x = self.channel_mapper(x) - return x - - def encode(self, x): - return self(x) - -class FrozenCLIPEmbedder(AbstractEncoder): - """Uses the CLIP transformer encoder for text (from Hugging Face)""" - def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): - super().__init__() - self.tokenizer = CLIPTokenizer.from_pretrained(version) - self.transformer = CLIPTextModel.from_pretrained(version) - self.device = device - self.max_length = max_length - self.freeze() - - def freeze(self): - self.transformer = self.transformer.eval() - for param in self.parameters(): - param.requires_grad = False - - def forward(self, text): - batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, - return_overflowing_tokens=False, padding="max_length", return_tensors="pt") - tokens = batch_encoding["input_ids"].to(self.device) - outputs = self.transformer(input_ids=tokens) - - z = outputs.last_hidden_state - return z - - def encode(self, text): - return self(text) - 
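A usage sketch for the Hugging Face based embedder above; this downloads
openai/clip-vit-large-patch14 on first use, and device="cpu" keeps the
example portable (the class defaults to "cuda"):

    import torch
    from ldm.modules.encoders.modules import FrozenCLIPEmbedder

    embedder = FrozenCLIPEmbedder(device="cpu")
    with torch.no_grad():
        z = embedder.encode(["a photograph of an astronaut riding a horse"])
    print(z.shape)  # torch.Size([1, 77, 768]) for clip-vit-large-patch14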
- -class FrozenCLIPTextEmbedder(nn.Module): - """ - Uses the CLIP transformer encoder for text. - """ - def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): - super().__init__() - self.model, _ = clip.load(version, jit=False, device="cpu") - self.device = device - self.max_length = max_length - self.n_repeat = n_repeat - self.normalize = normalize - - def freeze(self): - self.model = self.model.eval() - for param in self.parameters(): - param.requires_grad = False - - def forward(self, text): - tokens = clip.tokenize(text).to(self.device) - z = self.model.encode_text(tokens) - if self.normalize: - z = z / torch.linalg.norm(z, dim=1, keepdim=True) - return z - - def encode(self, text): - z = self(text) - if z.ndim==2: - z = z[:, None, :] - z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) - return z - - -class FrozenClipImageEmbedder(nn.Module): - """ - Uses the CLIP image encoder. - """ - def __init__( - self, - model, - jit=False, - device='cuda' if torch.cuda.is_available() else 'cpu', - antialias=False, - ): - super().__init__() - self.model, _ = clip.load(name=model, device=device, jit=jit) - - self.antialias = antialias - - self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) - self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) - - def preprocess(self, x): - # normalize to [0,1] - x = kornia.geometry.resize(x, (224, 224), - interpolation='bicubic',align_corners=True, - antialias=self.antialias) - x = (x + 1.) / 2. - # renormalize according to clip - x = kornia.enhance.normalize(x, self.mean, self.std) - return x - - def forward(self, x): - # x is assumed to be in range [-1,1] - return self.model.encode_image(self.preprocess(x)) - - -if __name__ == "__main__": - from ldm.util import count_params - model = FrozenCLIPEmbedder() - count_params(model, verbose=True) \ No newline at end of file diff --git a/ldm/modules/encoders/xlmr.py b/ldm/modules/encoders/xlmr.py deleted file mode 100644 index beab3fdf..00000000 --- a/ldm/modules/encoders/xlmr.py +++ /dev/null @@ -1,137 +0,0 @@ -from transformers import BertPreTrainedModel,BertModel,BertConfig -import torch.nn as nn -import torch -from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig -from transformers import XLMRobertaModel,XLMRobertaTokenizer -from typing import Optional - -class BertSeriesConfig(BertConfig): - def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): - - super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs) - self.project_dim = project_dim - self.pooler_fn = pooler_fn - self.learn_encoder = learn_encoder - -class RobertaSeriesConfig(XLMRobertaConfig): - def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs): - 
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) - self.project_dim = project_dim - self.pooler_fn = pooler_fn - self.learn_encoder = learn_encoder - - -class BertSeriesModelWithTransformation(BertPreTrainedModel): - - _keys_to_ignore_on_load_unexpected = [r"pooler"] - _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] - config_class = BertSeriesConfig - - def __init__(self, config=None, **kargs): - # modify initialization for autoloading - if config is None: - config = XLMRobertaConfig() - config.attention_probs_dropout_prob= 0.1 - config.bos_token_id=0 - config.eos_token_id=2 - config.hidden_act='gelu' - config.hidden_dropout_prob=0.1 - config.hidden_size=1024 - config.initializer_range=0.02 - config.intermediate_size=4096 - config.layer_norm_eps=1e-05 - config.max_position_embeddings=514 - - config.num_attention_heads=16 - config.num_hidden_layers=24 - config.output_past=True - config.pad_token_id=1 - config.position_embedding_type= "absolute" - - config.type_vocab_size= 1 - config.use_cache=True - config.vocab_size= 250002 - config.project_dim = 768 - config.learn_encoder = False - super().__init__(config) - self.roberta = XLMRobertaModel(config) - self.transformation = nn.Linear(config.hidden_size,config.project_dim) - self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') - self.pooler = lambda x: x[:,0] - self.post_init() - - def encode(self,c): - device = next(self.parameters()).device - text = self.tokenizer(c, - truncation=True, - max_length=77, - return_length=False, - return_overflowing_tokens=False, - padding="max_length", - return_tensors="pt") - text["input_ids"] = torch.tensor(text["input_ids"]).to(device) - text["attention_mask"] = torch.tensor( - text['attention_mask']).to(device) - features = self(**text) - return features['projection_state'] - - def forward( - self, - input_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - head_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - return_dict: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - ) : - r""" - """ - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - - outputs = self.roberta( - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=True, - return_dict=return_dict, - ) - - # last module outputs - sequence_output = outputs[0] - - - # project every module - sequence_output_ln = self.pre_LN(sequence_output) - - # pooler - pooler_output = self.pooler(sequence_output_ln) - pooler_output = self.transformation(pooler_output) - projection_state = self.transformation(outputs.last_hidden_state) - - return { - 'pooler_output':pooler_output, - 'last_hidden_state':outputs.last_hidden_state, - 'hidden_states':outputs.hidden_states, - 'attentions':outputs.attentions, - 
'projection_state':projection_state, - 'sequence_out': sequence_output - } - - -class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): - base_model_prefix = 'roberta' - config_class= RobertaSeriesConfig \ No newline at end of file diff --git a/ldm/modules/image_degradation/__init__.py b/ldm/modules/image_degradation/__init__.py deleted file mode 100644 index 7836cada..00000000 --- a/ldm/modules/image_degradation/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr -from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light diff --git a/ldm/modules/image_degradation/bsrgan.py b/ldm/modules/image_degradation/bsrgan.py deleted file mode 100644 index 32ef5616..00000000 --- a/ldm/modules/image_degradation/bsrgan.py +++ /dev/null @@ -1,730 +0,0 @@ -# -*- coding: utf-8 -*- -""" -# -------------------------------------------- -# Super-Resolution -# -------------------------------------------- -# -# Kai Zhang (cskaizhang@gmail.com) -# https://github.com/cszn -# From 2019/03--2021/08 -# -------------------------------------------- -""" - -import numpy as np -import cv2 -import torch - -from functools import partial -import random -from scipy import ndimage -import scipy -import scipy.stats as ss -from scipy.interpolate import interp2d -from scipy.linalg import orth -import albumentations - -import ldm.modules.image_degradation.utils_image as util - - -def modcrop_np(img, sf): - ''' - Args: - img: numpy image, WxH or WxHxC - sf: scale factor - Return: - cropped image - ''' - w, h = img.shape[:2] - im = np.copy(img) - return im[:w - w % sf, :h - h % sf, ...] - - -""" -# -------------------------------------------- -# anisotropic Gaussian kernels -# -------------------------------------------- -""" - - -def analytic_kernel(k): - """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" - k_size = k.shape[0] - # Calculate the big kernels size - big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) - # Loop over the small kernel to fill the big one - for r in range(k_size): - for c in range(k_size): - big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k - # Crop the edges of the big kernel to ignore very small values and increase run time of SR - crop = k_size // 2 - cropped_big_k = big_k[crop:-crop, crop:-crop] - # Normalize to 1 - return cropped_big_k / cropped_big_k.sum() - - -def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): - """ generate an anisotropic Gaussian kernel - Args: - ksize : e.g., 15, kernel size - theta : [0, pi], rotation angle range - l1 : [0.1,50], scaling of eigenvalues - l2 : [0.1,l1], scaling of eigenvalues - If l1 = l2, will get an isotropic Gaussian kernel. 
- Returns: - k : kernel - """ - - v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) - V = np.array([[v[0], v[1]], [v[1], -v[0]]]) - D = np.array([[l1, 0], [0, l2]]) - Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) - k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) - - return k - - -def gm_blur_kernel(mean, cov, size=15): - center = size / 2.0 + 0.5 - k = np.zeros([size, size]) - for y in range(size): - for x in range(size): - cy = y - center + 1 - cx = x - center + 1 - k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) - - k = k / np.sum(k) - return k - - -def shift_pixel(x, sf, upper_left=True): - """shift pixel for super-resolution with different scale factors - Args: - x: WxHxC or WxH - sf: scale factor - upper_left: shift direction - """ - h, w = x.shape[:2] - shift = (sf - 1) * 0.5 - xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) - if upper_left: - x1 = xv + shift - y1 = yv + shift - else: - x1 = xv - shift - y1 = yv - shift - - x1 = np.clip(x1, 0, w - 1) - y1 = np.clip(y1, 0, h - 1) - - if x.ndim == 2: - x = interp2d(xv, yv, x)(x1, y1) - if x.ndim == 3: - for i in range(x.shape[-1]): - x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) - - return x - - -def blur(x, k): - ''' - x: image, NxcxHxW - k: kernel, Nx1xhxw - ''' - n, c = x.shape[:2] - p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 - x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') - k = k.repeat(1, c, 1, 1) - k = k.view(-1, 1, k.shape[2], k.shape[3]) - x = x.view(1, -1, x.shape[2], x.shape[3]) - x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) - x = x.view(n, c, x.shape[2], x.shape[3]) - - return x - - -def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): - """" - # modified version of https://github.com/assafshocher/BlindSR_dataset_generator - # Kai Zhang - # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var - # max_var = 2.5 * sf - """ - # Set random eigen-vals (lambdas) and angle (theta) for COV matrix - lambda_1 = min_var + np.random.rand() * (max_var - min_var) - lambda_2 = min_var + np.random.rand() * (max_var - min_var) - theta = np.random.rand() * np.pi # random theta - noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 - - # Set COV matrix using Lambdas and Theta - LAMBDA = np.diag([lambda_1, lambda_2]) - Q = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) - SIGMA = Q @ LAMBDA @ Q.T - INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] - - # Set expectation position (shifting kernel for aligned image) - MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) - MU = MU[None, None, :, None] - - # Create meshgrid for Gaussian - [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) - Z = np.stack([X, Y], 2)[:, :, :, None] - - # Calcualte Gaussian for every pixel of the kernel - ZZ = Z - MU - ZZ_t = ZZ.transpose(0, 1, 3, 2) - raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) - - # shift the kernel so it will be centered - # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) - - # Normalize the kernel and return - # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) - kernel = raw_kernel / np.sum(raw_kernel) - return kernel - - -def fspecial_gaussian(hsize, sigma): - hsize = [hsize, hsize] - siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] - std = 
sigma - [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) - arg = -(x * x + y * y) / (2 * std * std) - h = np.exp(arg) - h[h < scipy.finfo(float).eps * h.max()] = 0 - sumh = h.sum() - if sumh != 0: - h = h / sumh - return h - - -def fspecial_laplacian(alpha): - alpha = max([0, min([alpha, 1])]) - h1 = alpha / (alpha + 1) - h2 = (1 - alpha) / (alpha + 1) - h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] - h = np.array(h) - return h - - -def fspecial(filter_type, *args, **kwargs): - ''' - python code from: - https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py - ''' - if filter_type == 'gaussian': - return fspecial_gaussian(*args, **kwargs) - if filter_type == 'laplacian': - return fspecial_laplacian(*args, **kwargs) - - -""" -# -------------------------------------------- -# degradation models -# -------------------------------------------- -""" - - -def bicubic_degradation(x, sf=3): - ''' - Args: - x: HxWxC image, [0, 1] - sf: down-scale factor - Return: - bicubicly downsampled LR image - ''' - x = util.imresize_np(x, scale=1 / sf) - return x - - -def srmd_degradation(x, k, sf=3): - ''' blur + bicubic downsampling - Args: - x: HxWxC image, [0, 1] - k: hxw, double - sf: down-scale factor - Return: - downsampled LR image - Reference: - @inproceedings{zhang2018learning, - title={Learning a single convolutional super-resolution network for multiple degradations}, - author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, - booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, - pages={3262--3271}, - year={2018} - } - ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' - x = bicubic_degradation(x, sf=sf) - return x - - -def dpsr_degradation(x, k, sf=3): - ''' bicubic downsampling + blur - Args: - x: HxWxC image, [0, 1] - k: hxw, double - sf: down-scale factor - Return: - downsampled LR image - Reference: - @inproceedings{zhang2019deep, - title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, - author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, - booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, - pages={1671--1681}, - year={2019} - } - ''' - x = bicubic_degradation(x, sf=sf) - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') - return x - - -def classical_degradation(x, k, sf=3): - ''' blur + downsampling - Args: - x: HxWxC image, [0, 1]/[0, 255] - k: hxw, double - sf: down-scale factor - Return: - downsampled LR image - ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') - # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) - st = 0 - return x[st::sf, st::sf, ...] - - -def add_sharpening(img, weight=0.5, radius=50, threshold=10): - """USM sharpening. borrowed from real-ESRGAN - Input image: I; Blurry image: B. - 1. K = I + weight * (I - B) - 2. Mask = 1 if abs(I - B) > threshold, else: 0 - 3. Blur mask: - 4. Out = Mask * K + (1 - Mask) * I - Args: - img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. - weight (float): Sharp weight. Default: 1. - radius (float): Kernel size of Gaussian blur. Default: 50. 
- threshold (int): - """ - if radius % 2 == 0: - radius += 1 - blur = cv2.GaussianBlur(img, (radius, radius), 0) - residual = img - blur - mask = np.abs(residual) * 255 > threshold - mask = mask.astype('float32') - soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) - - K = img + weight * residual - K = np.clip(K, 0, 1) - return soft_mask * K + (1 - soft_mask) * img - - -def add_blur(img, sf=4): - wd2 = 4.0 + sf - wd = 2.0 + 0.2 * sf - if random.random() < 0.5: - l1 = wd2 * random.random() - l2 = wd2 * random.random() - k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) - else: - k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) - img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') - - return img - - -def add_resize(img, sf=4): - rnum = np.random.rand() - if rnum > 0.8: # up - sf1 = random.uniform(1, 2) - elif rnum < 0.7: # down - sf1 = random.uniform(0.5 / sf, 1) - else: - sf1 = 1.0 - img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) - img = np.clip(img, 0.0, 1.0) - - return img - - -# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): -# noise_level = random.randint(noise_level1, noise_level2) -# rnum = np.random.rand() -# if rnum > 0.6: # add color Gaussian noise -# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) -# elif rnum < 0.4: # add grayscale Gaussian noise -# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) -# else: # add noise -# L = noise_level2 / 255. -# D = np.diag(np.random.rand(3)) -# U = orth(np.random.rand(3, 3)) -# conv = np.dot(np.dot(np.transpose(U), D), U) -# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) -# img = np.clip(img, 0.0, 1.0) -# return img - -def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): - noise_level = random.randint(noise_level1, noise_level2) - rnum = np.random.rand() - if rnum > 0.6: # add color Gaussian noise - img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) - elif rnum < 0.4: # add grayscale Gaussian noise - img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) - else: # add noise - L = noise_level2 / 255. - D = np.diag(np.random.rand(3)) - U = orth(np.random.rand(3, 3)) - conv = np.dot(np.dot(np.transpose(U), D), U) - img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) - img = np.clip(img, 0.0, 1.0) - return img - - -def add_speckle_noise(img, noise_level1=2, noise_level2=25): - noise_level = random.randint(noise_level1, noise_level2) - img = np.clip(img, 0.0, 1.0) - rnum = random.random() - if rnum > 0.6: - img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) - elif rnum < 0.4: - img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) - else: - L = noise_level2 / 255. - D = np.diag(np.random.rand(3)) - U = orth(np.random.rand(3, 3)) - conv = np.dot(np.dot(np.transpose(U), D), U) - img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) - img = np.clip(img, 0.0, 1.0) - return img - - -def add_Poisson_noise(img): - img = np.clip((img * 255.0).round(), 0, 255) / 255. 
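    # `vals` below is the effective photon count per unit intensity, drawn log-uniformly from [1e2, 1e4];
    # larger values mean more photons and therefore weaker relative shot (Poisson) noise.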
- vals = 10 ** (2 * random.random() + 2.0) # [2, 4] - if random.random() < 0.5: - img = np.random.poisson(img * vals).astype(np.float32) / vals - else: - img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) - img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. - noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray - img += noise_gray[:, :, np.newaxis] - img = np.clip(img, 0.0, 1.0) - return img - - -def add_JPEG_noise(img): - quality_factor = random.randint(30, 95) - img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) - result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) - img = cv2.imdecode(encimg, 1) - img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) - return img - - -def random_crop(lq, hq, sf=4, lq_patchsize=64): - h, w = lq.shape[:2] - rnd_h = random.randint(0, h - lq_patchsize) - rnd_w = random.randint(0, w - lq_patchsize) - lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] - - rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) - hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] - return lq, hq - - -def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): - """ - This is the degradation model of BSRGAN from the paper - "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" - ---------- - img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) - sf: scale factor - isp_model: camera ISP model - Returns - ------- - img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] - hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] - """ - isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 - sf_ori = sf - - h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop - h, w = img.shape[:2] - - if h < lq_patchsize * sf or w < lq_patchsize * sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') - - hq = img.copy() - - if sf == 4 and random.random() < scale2_prob: # downsample1 - if np.random.rand() < 0.5: - img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) - else: - img = util.imresize_np(img, 1 / 2, True) - img = np.clip(img, 0.0, 1.0) - sf = 2 - - shuffle_order = random.sample(range(7), 7) - idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) - if idx1 > idx2: # keep downsample3 last - shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] - - for i in shuffle_order: - - if i == 0: - img = add_blur(img, sf=sf) - - elif i == 1: - img = add_blur(img, sf=sf) - - elif i == 2: - a, b = img.shape[1], img.shape[0] - # downsample2 - if random.random() < 0.75: - sf1 = random.uniform(1, 2 * sf) - img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) - else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) - k_shifted = shift_pixel(k, sf) - k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') - img = img[0::sf, 0::sf, ...] 
# nearest downsampling - img = np.clip(img, 0.0, 1.0) - - elif i == 3: - # downsample3 - img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) - img = np.clip(img, 0.0, 1.0) - - elif i == 4: - # add Gaussian noise - img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) - - elif i == 5: - # add JPEG noise - if random.random() < jpeg_prob: - img = add_JPEG_noise(img) - - elif i == 6: - # add processed camera sensor noise - if random.random() < isp_prob and isp_model is not None: - with torch.no_grad(): - img, hq = isp_model.forward(img.copy(), hq) - - # add final JPEG compression noise - img = add_JPEG_noise(img) - - # random crop - img, hq = random_crop(img, hq, sf_ori, lq_patchsize) - - return img, hq - - -# todo no isp_model? -def degradation_bsrgan_variant(image, sf=4, isp_model=None): - """ - This is the degradation model of BSRGAN from the paper - "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" - ---------- - sf: scale factor - isp_model: camera ISP model - Returns - ------- - img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] - hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] - """ - image = util.uint2single(image) - isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 - sf_ori = sf - - h1, w1 = image.shape[:2] - image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop - h, w = image.shape[:2] - - hq = image.copy() - - if sf == 4 and random.random() < scale2_prob: # downsample1 - if np.random.rand() < 0.5: - image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) - else: - image = util.imresize_np(image, 1 / 2, True) - image = np.clip(image, 0.0, 1.0) - sf = 2 - - shuffle_order = random.sample(range(7), 7) - idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) - if idx1 > idx2: # keep downsample3 last - shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] - - for i in shuffle_order: - - if i == 0: - image = add_blur(image, sf=sf) - - elif i == 1: - image = add_blur(image, sf=sf) - - elif i == 2: - a, b = image.shape[1], image.shape[0] - # downsample2 - if random.random() < 0.75: - sf1 = random.uniform(1, 2 * sf) - image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) - else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) - k_shifted = shift_pixel(k, sf) - k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') - image = image[0::sf, 0::sf, ...] 
# nearest downsampling - image = np.clip(image, 0.0, 1.0) - - elif i == 3: - # downsample3 - image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) - image = np.clip(image, 0.0, 1.0) - - elif i == 4: - # add Gaussian noise - image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) - - elif i == 5: - # add JPEG noise - if random.random() < jpeg_prob: - image = add_JPEG_noise(image) - - # elif i == 6: - # # add processed camera sensor noise - # if random.random() < isp_prob and isp_model is not None: - # with torch.no_grad(): - # img, hq = isp_model.forward(img.copy(), hq) - - # add final JPEG compression noise - image = add_JPEG_noise(image) - image = util.single2uint(image) - example = {"image":image} - return example - - -# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... -def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): - """ - This is an extended degradation model by combining - the degradation models of BSRGAN and Real-ESRGAN - ---------- - img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) - sf: scale factor - use_shuffle: the degradation shuffle - use_sharp: sharpening the img - Returns - ------- - img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] - hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] - """ - - h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop - h, w = img.shape[:2] - - if h < lq_patchsize * sf or w < lq_patchsize * sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') - - if use_sharp: - img = add_sharpening(img) - hq = img.copy() - - if random.random() < shuffle_prob: - shuffle_order = random.sample(range(13), 13) - else: - shuffle_order = list(range(13)) - # local shuffle for noise, JPEG is always the last one - shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) - shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) - - poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 - - for i in shuffle_order: - if i == 0: - img = add_blur(img, sf=sf) - elif i == 1: - img = add_resize(img, sf=sf) - elif i == 2: - img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) - elif i == 3: - if random.random() < poisson_prob: - img = add_Poisson_noise(img) - elif i == 4: - if random.random() < speckle_prob: - img = add_speckle_noise(img) - elif i == 5: - if random.random() < isp_prob and isp_model is not None: - with torch.no_grad(): - img, hq = isp_model.forward(img.copy(), hq) - elif i == 6: - img = add_JPEG_noise(img) - elif i == 7: - img = add_blur(img, sf=sf) - elif i == 8: - img = add_resize(img, sf=sf) - elif i == 9: - img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) - elif i == 10: - if random.random() < poisson_prob: - img = add_Poisson_noise(img) - elif i == 11: - if random.random() < speckle_prob: - img = add_speckle_noise(img) - elif i == 12: - if random.random() < isp_prob and isp_model is not None: - with torch.no_grad(): - img, hq = isp_model.forward(img.copy(), hq) - else: - print('check the shuffle!') - - # resize to desired size - img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), - interpolation=random.choice([1, 2, 3])) - - # add final JPEG compression noise - img = add_JPEG_noise(img) - - # random crop - img, hq = random_crop(img, hq, sf, lq_patchsize) - 
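    # At this point img is the fully degraded LQ patch and hq the aligned HQ patch (lq_patchsize*sf per side).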
- return img, hq - - -if __name__ == '__main__': - print("hey") - img = util.imread_uint('utils/test.png', 3) - print(img) - img = util.uint2single(img) - print(img) - img = img[:448, :448] - h = img.shape[0] // 4 - print("resizing to", h) - sf = 4 - deg_fn = partial(degradation_bsrgan_variant, sf=sf) - for i in range(20): - print(i) - img_lq = deg_fn(img) - print(img_lq) - img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] - print(img_lq.shape) - print("bicubic", img_lq_bicubic.shape) - print(img_hq.shape) - lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) - util.imsave(img_concat, str(i) + '.png') - - diff --git a/ldm/modules/image_degradation/bsrgan_light.py b/ldm/modules/image_degradation/bsrgan_light.py deleted file mode 100644 index 9e1f8239..00000000 --- a/ldm/modules/image_degradation/bsrgan_light.py +++ /dev/null @@ -1,650 +0,0 @@ -# -*- coding: utf-8 -*- -import numpy as np -import cv2 -import torch - -from functools import partial -import random -from scipy import ndimage -import scipy -import scipy.stats as ss -from scipy.interpolate import interp2d -from scipy.linalg import orth -import albumentations - -import ldm.modules.image_degradation.utils_image as util - -""" -# -------------------------------------------- -# Super-Resolution -# -------------------------------------------- -# -# Kai Zhang (cskaizhang@gmail.com) -# https://github.com/cszn -# From 2019/03--2021/08 -# -------------------------------------------- -""" - - -def modcrop_np(img, sf): - ''' - Args: - img: numpy image, WxH or WxHxC - sf: scale factor - Return: - cropped image - ''' - w, h = img.shape[:2] - im = np.copy(img) - return im[:w - w % sf, :h - h % sf, ...] - - -""" -# -------------------------------------------- -# anisotropic Gaussian kernels -# -------------------------------------------- -""" - - -def analytic_kernel(k): - """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" - k_size = k.shape[0] - # Calculate the big kernels size - big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) - # Loop over the small kernel to fill the big one - for r in range(k_size): - for c in range(k_size): - big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k - # Crop the edges of the big kernel to ignore very small values and increase run time of SR - crop = k_size // 2 - cropped_big_k = big_k[crop:-crop, crop:-crop] - # Normalize to 1 - return cropped_big_k / cropped_big_k.sum() - - -def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): - """ generate an anisotropic Gaussian kernel - Args: - ksize : e.g., 15, kernel size - theta : [0, pi], rotation angle range - l1 : [0.1,50], scaling of eigenvalues - l2 : [0.1,l1], scaling of eigenvalues - If l1 = l2, will get an isotropic Gaussian kernel. 
- Returns: - k : kernel - """ - - v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) - V = np.array([[v[0], v[1]], [v[1], -v[0]]]) - D = np.array([[l1, 0], [0, l2]]) - Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) - k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) - - return k - - -def gm_blur_kernel(mean, cov, size=15): - center = size / 2.0 + 0.5 - k = np.zeros([size, size]) - for y in range(size): - for x in range(size): - cy = y - center + 1 - cx = x - center + 1 - k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) - - k = k / np.sum(k) - return k - - -def shift_pixel(x, sf, upper_left=True): - """shift pixel for super-resolution with different scale factors - Args: - x: WxHxC or WxH - sf: scale factor - upper_left: shift direction - """ - h, w = x.shape[:2] - shift = (sf - 1) * 0.5 - xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) - if upper_left: - x1 = xv + shift - y1 = yv + shift - else: - x1 = xv - shift - y1 = yv - shift - - x1 = np.clip(x1, 0, w - 1) - y1 = np.clip(y1, 0, h - 1) - - if x.ndim == 2: - x = interp2d(xv, yv, x)(x1, y1) - if x.ndim == 3: - for i in range(x.shape[-1]): - x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) - - return x - - -def blur(x, k): - ''' - x: image, NxcxHxW - k: kernel, Nx1xhxw - ''' - n, c = x.shape[:2] - p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 - x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') - k = k.repeat(1, c, 1, 1) - k = k.view(-1, 1, k.shape[2], k.shape[3]) - x = x.view(1, -1, x.shape[2], x.shape[3]) - x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) - x = x.view(n, c, x.shape[2], x.shape[3]) - - return x - - -def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): - """" - # modified version of https://github.com/assafshocher/BlindSR_dataset_generator - # Kai Zhang - # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var - # max_var = 2.5 * sf - """ - # Set random eigen-vals (lambdas) and angle (theta) for COV matrix - lambda_1 = min_var + np.random.rand() * (max_var - min_var) - lambda_2 = min_var + np.random.rand() * (max_var - min_var) - theta = np.random.rand() * np.pi # random theta - noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 - - # Set COV matrix using Lambdas and Theta - LAMBDA = np.diag([lambda_1, lambda_2]) - Q = np.array([[np.cos(theta), -np.sin(theta)], - [np.sin(theta), np.cos(theta)]]) - SIGMA = Q @ LAMBDA @ Q.T - INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] - - # Set expectation position (shifting kernel for aligned image) - MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) - MU = MU[None, None, :, None] - - # Create meshgrid for Gaussian - [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) - Z = np.stack([X, Y], 2)[:, :, :, None] - - # Calcualte Gaussian for every pixel of the kernel - ZZ = Z - MU - ZZ_t = ZZ.transpose(0, 1, 3, 2) - raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) - - # shift the kernel so it will be centered - # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) - - # Normalize the kernel and return - # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) - kernel = raw_kernel / np.sum(raw_kernel) - return kernel - - -def fspecial_gaussian(hsize, sigma): - hsize = [hsize, hsize] - siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] - std = 
sigma - [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) - arg = -(x * x + y * y) / (2 * std * std) - h = np.exp(arg) - h[h < scipy.finfo(float).eps * h.max()] = 0 - sumh = h.sum() - if sumh != 0: - h = h / sumh - return h - - -def fspecial_laplacian(alpha): - alpha = max([0, min([alpha, 1])]) - h1 = alpha / (alpha + 1) - h2 = (1 - alpha) / (alpha + 1) - h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] - h = np.array(h) - return h - - -def fspecial(filter_type, *args, **kwargs): - ''' - python code from: - https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py - ''' - if filter_type == 'gaussian': - return fspecial_gaussian(*args, **kwargs) - if filter_type == 'laplacian': - return fspecial_laplacian(*args, **kwargs) - - -""" -# -------------------------------------------- -# degradation models -# -------------------------------------------- -""" - - -def bicubic_degradation(x, sf=3): - ''' - Args: - x: HxWxC image, [0, 1] - sf: down-scale factor - Return: - bicubicly downsampled LR image - ''' - x = util.imresize_np(x, scale=1 / sf) - return x - - -def srmd_degradation(x, k, sf=3): - ''' blur + bicubic downsampling - Args: - x: HxWxC image, [0, 1] - k: hxw, double - sf: down-scale factor - Return: - downsampled LR image - Reference: - @inproceedings{zhang2018learning, - title={Learning a single convolutional super-resolution network for multiple degradations}, - author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, - booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, - pages={3262--3271}, - year={2018} - } - ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' - x = bicubic_degradation(x, sf=sf) - return x - - -def dpsr_degradation(x, k, sf=3): - ''' bicubic downsampling + blur - Args: - x: HxWxC image, [0, 1] - k: hxw, double - sf: down-scale factor - Return: - downsampled LR image - Reference: - @inproceedings{zhang2019deep, - title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, - author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, - booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, - pages={1671--1681}, - year={2019} - } - ''' - x = bicubic_degradation(x, sf=sf) - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') - return x - - -def classical_degradation(x, k, sf=3): - ''' blur + downsampling - Args: - x: HxWxC image, [0, 1]/[0, 255] - k: hxw, double - sf: down-scale factor - Return: - downsampled LR image - ''' - x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') - # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) - st = 0 - return x[st::sf, st::sf, ...] - - -def add_sharpening(img, weight=0.5, radius=50, threshold=10): - """USM sharpening. borrowed from real-ESRGAN - Input image: I; Blurry image: B. - 1. K = I + weight * (I - B) - 2. Mask = 1 if abs(I - B) > threshold, else: 0 - 3. Blur mask: - 4. Out = Mask * K + (1 - Mask) * I - Args: - img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. - weight (float): Sharp weight. Default: 1. - radius (float): Kernel size of Gaussian blur. Default: 50. 
- threshold (int): - """ - if radius % 2 == 0: - radius += 1 - blur = cv2.GaussianBlur(img, (radius, radius), 0) - residual = img - blur - mask = np.abs(residual) * 255 > threshold - mask = mask.astype('float32') - soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) - - K = img + weight * residual - K = np.clip(K, 0, 1) - return soft_mask * K + (1 - soft_mask) * img - - -def add_blur(img, sf=4): - wd2 = 4.0 + sf - wd = 2.0 + 0.2 * sf - - wd2 = wd2/4 - wd = wd/4 - - if random.random() < 0.5: - l1 = wd2 * random.random() - l2 = wd2 * random.random() - k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) - else: - k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) - img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') - - return img - - -def add_resize(img, sf=4): - rnum = np.random.rand() - if rnum > 0.8: # up - sf1 = random.uniform(1, 2) - elif rnum < 0.7: # down - sf1 = random.uniform(0.5 / sf, 1) - else: - sf1 = 1.0 - img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) - img = np.clip(img, 0.0, 1.0) - - return img - - -# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): -# noise_level = random.randint(noise_level1, noise_level2) -# rnum = np.random.rand() -# if rnum > 0.6: # add color Gaussian noise -# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) -# elif rnum < 0.4: # add grayscale Gaussian noise -# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) -# else: # add noise -# L = noise_level2 / 255. -# D = np.diag(np.random.rand(3)) -# U = orth(np.random.rand(3, 3)) -# conv = np.dot(np.dot(np.transpose(U), D), U) -# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) -# img = np.clip(img, 0.0, 1.0) -# return img - -def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): - noise_level = random.randint(noise_level1, noise_level2) - rnum = np.random.rand() - if rnum > 0.6: # add color Gaussian noise - img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) - elif rnum < 0.4: # add grayscale Gaussian noise - img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) - else: # add noise - L = noise_level2 / 255. - D = np.diag(np.random.rand(3)) - U = orth(np.random.rand(3, 3)) - conv = np.dot(np.dot(np.transpose(U), D), U) - img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) - img = np.clip(img, 0.0, 1.0) - return img - - -def add_speckle_noise(img, noise_level1=2, noise_level2=25): - noise_level = random.randint(noise_level1, noise_level2) - img = np.clip(img, 0.0, 1.0) - rnum = random.random() - if rnum > 0.6: - img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) - elif rnum < 0.4: - img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) - else: - L = noise_level2 / 255. - D = np.diag(np.random.rand(3)) - U = orth(np.random.rand(3, 3)) - conv = np.dot(np.dot(np.transpose(U), D), U) - img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) - img = np.clip(img, 0.0, 1.0) - return img - - -def add_Poisson_noise(img): - img = np.clip((img * 255.0).round(), 0, 255) / 255. 
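    # Same shot-noise model as in bsrgan.py: `vals` is a log-uniform photon count in [1e2, 1e4];
    # the image is first quantized to 8-bit levels so the Poisson rates match uint8 intensities.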
- vals = 10 ** (2 * random.random() + 2.0) # [2, 4] - if random.random() < 0.5: - img = np.random.poisson(img * vals).astype(np.float32) / vals - else: - img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) - img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. - noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray - img += noise_gray[:, :, np.newaxis] - img = np.clip(img, 0.0, 1.0) - return img - - -def add_JPEG_noise(img): - quality_factor = random.randint(80, 95) - img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) - result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) - img = cv2.imdecode(encimg, 1) - img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) - return img - - -def random_crop(lq, hq, sf=4, lq_patchsize=64): - h, w = lq.shape[:2] - rnd_h = random.randint(0, h - lq_patchsize) - rnd_w = random.randint(0, w - lq_patchsize) - lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] - - rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) - hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] - return lq, hq - - -def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): - """ - This is the degradation model of BSRGAN from the paper - "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" - ---------- - img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) - sf: scale factor - isp_model: camera ISP model - Returns - ------- - img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] - hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] - """ - isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 - sf_ori = sf - - h1, w1 = img.shape[:2] - img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop - h, w = img.shape[:2] - - if h < lq_patchsize * sf or w < lq_patchsize * sf: - raise ValueError(f'img size ({h1}X{w1}) is too small!') - - hq = img.copy() - - if sf == 4 and random.random() < scale2_prob: # downsample1 - if np.random.rand() < 0.5: - img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) - else: - img = util.imresize_np(img, 1 / 2, True) - img = np.clip(img, 0.0, 1.0) - sf = 2 - - shuffle_order = random.sample(range(7), 7) - idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) - if idx1 > idx2: # keep downsample3 last - shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] - - for i in shuffle_order: - - if i == 0: - img = add_blur(img, sf=sf) - - elif i == 1: - img = add_blur(img, sf=sf) - - elif i == 2: - a, b = img.shape[1], img.shape[0] - # downsample2 - if random.random() < 0.75: - sf1 = random.uniform(1, 2 * sf) - img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), - interpolation=random.choice([1, 2, 3])) - else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) - k_shifted = shift_pixel(k, sf) - k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') - img = img[0::sf, 0::sf, ...] 
# nearest downsampling - img = np.clip(img, 0.0, 1.0) - - elif i == 3: - # downsample3 - img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) - img = np.clip(img, 0.0, 1.0) - - elif i == 4: - # add Gaussian noise - img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8) - - elif i == 5: - # add JPEG noise - if random.random() < jpeg_prob: - img = add_JPEG_noise(img) - - elif i == 6: - # add processed camera sensor noise - if random.random() < isp_prob and isp_model is not None: - with torch.no_grad(): - img, hq = isp_model.forward(img.copy(), hq) - - # add final JPEG compression noise - img = add_JPEG_noise(img) - - # random crop - img, hq = random_crop(img, hq, sf_ori, lq_patchsize) - - return img, hq - - -# todo no isp_model? -def degradation_bsrgan_variant(image, sf=4, isp_model=None): - """ - This is the degradation model of BSRGAN from the paper - "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" - ---------- - sf: scale factor - isp_model: camera ISP model - Returns - ------- - img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] - hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] - """ - image = util.uint2single(image) - isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 - sf_ori = sf - - h1, w1 = image.shape[:2] - image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop - h, w = image.shape[:2] - - hq = image.copy() - - if sf == 4 and random.random() < scale2_prob: # downsample1 - if np.random.rand() < 0.5: - image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) - else: - image = util.imresize_np(image, 1 / 2, True) - image = np.clip(image, 0.0, 1.0) - sf = 2 - - shuffle_order = random.sample(range(7), 7) - idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) - if idx1 > idx2: # keep downsample3 last - shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] - - for i in shuffle_order: - - if i == 0: - image = add_blur(image, sf=sf) - - # elif i == 1: - # image = add_blur(image, sf=sf) - - if i == 0: - pass - - elif i == 2: - a, b = image.shape[1], image.shape[0] - # downsample2 - if random.random() < 0.8: - sf1 = random.uniform(1, 2 * sf) - image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), - interpolation=random.choice([1, 2, 3])) - else: - k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) - k_shifted = shift_pixel(k, sf) - k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel - image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') - image = image[0::sf, 0::sf, ...] 
# nearest downsampling - - image = np.clip(image, 0.0, 1.0) - - elif i == 3: - # downsample3 - image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) - image = np.clip(image, 0.0, 1.0) - - elif i == 4: - # add Gaussian noise - image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) - - elif i == 5: - # add JPEG noise - if random.random() < jpeg_prob: - image = add_JPEG_noise(image) - # - # elif i == 6: - # # add processed camera sensor noise - # if random.random() < isp_prob and isp_model is not None: - # with torch.no_grad(): - # img, hq = isp_model.forward(img.copy(), hq) - - # add final JPEG compression noise - image = add_JPEG_noise(image) - image = util.single2uint(image) - example = {"image": image} - return example - - - - -if __name__ == '__main__': - print("hey") - img = util.imread_uint('utils/test.png', 3) - img = img[:448, :448] - h = img.shape[0] // 4 - print("resizing to", h) - sf = 4 - deg_fn = partial(degradation_bsrgan_variant, sf=sf) - for i in range(20): - print(i) - img_hq = img - img_lq = deg_fn(img)["image"] - img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) - print(img_lq) - img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] - print(img_lq.shape) - print("bicubic", img_lq_bicubic.shape) - print(img_hq.shape) - lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), - (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), - interpolation=0) - img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) - util.imsave(img_concat, str(i) + '.png') diff --git a/ldm/modules/image_degradation/utils/test.png b/ldm/modules/image_degradation/utils/test.png deleted file mode 100644 index 4249b43d..00000000 Binary files a/ldm/modules/image_degradation/utils/test.png and /dev/null differ diff --git a/ldm/modules/image_degradation/utils_image.py b/ldm/modules/image_degradation/utils_image.py deleted file mode 100644 index 0175f155..00000000 --- a/ldm/modules/image_degradation/utils_image.py +++ /dev/null @@ -1,916 +0,0 @@ -import os -import math -import random -import numpy as np -import torch -import cv2 -from torchvision.utils import make_grid -from datetime import datetime -#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py - - -os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" - - -''' -# -------------------------------------------- -# Kai Zhang (github: https://github.com/cszn) -# 03/Mar/2019 -# -------------------------------------------- -# https://github.com/twhui/SRGAN-pyTorch -# https://github.com/xinntao/BasicSR -# -------------------------------------------- -''' - - -IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] - - -def is_image_file(filename): - return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) - - -def get_timestamp(): - return datetime.now().strftime('%y%m%d-%H%M%S') - - -def imshow(x, title=None, cbar=False, figsize=None): - plt.figure(figsize=figsize) - plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') - if title: - plt.title(title) - if cbar: - plt.colorbar() - plt.show() - - -def surf(Z, cmap='rainbow', figsize=None): - plt.figure(figsize=figsize) - ax3 = plt.axes(projection='3d') - - w, h = Z.shape[:2] - xx = 
np.arange(0,w,1) - yy = np.arange(0,h,1) - X, Y = np.meshgrid(xx, yy) - ax3.plot_surface(X,Y,Z,cmap=cmap) - #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) - plt.show() - - -''' -# -------------------------------------------- -# get image pathes -# -------------------------------------------- -''' - - -def get_image_paths(dataroot): - paths = None # return None if dataroot is None - if dataroot is not None: - paths = sorted(_get_paths_from_images(dataroot)) - return paths - - -def _get_paths_from_images(path): - assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) - images = [] - for dirpath, _, fnames in sorted(os.walk(path)): - for fname in sorted(fnames): - if is_image_file(fname): - img_path = os.path.join(dirpath, fname) - images.append(img_path) - assert images, '{:s} has no valid image file'.format(path) - return images - - -''' -# -------------------------------------------- -# split large images into small images -# -------------------------------------------- -''' - - -def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): - w, h = img.shape[:2] - patches = [] - if w > p_max and h > p_max: - w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) - h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) - w1.append(w-p_size) - h1.append(h-p_size) -# print(w1) -# print(h1) - for i in w1: - for j in h1: - patches.append(img[i:i+p_size, j:j+p_size,:]) - else: - patches.append(img) - - return patches - - -def imssave(imgs, img_path): - """ - imgs: list, N images of size WxHxC - """ - img_name, ext = os.path.splitext(os.path.basename(img_path)) - - for i, img in enumerate(imgs): - if img.ndim == 3: - img = img[:, :, [2, 1, 0]] - new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') - cv2.imwrite(new_path, img) - - -def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000): - """ - split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), - and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) - will be splitted. - Args: - original_dataroot: - taget_dataroot: - p_size: size of small images - p_overlap: patch size in training is a good choice - p_max: images with smaller size than (p_max)x(p_max) keep unchanged. - """ - paths = get_image_paths(original_dataroot) - for img_path in paths: - # img_name, ext = os.path.splitext(os.path.basename(img_path)) - img = imread_uint(img_path, n_channels=n_channels) - patches = patches_from_image(img, p_size, p_overlap, p_max) - imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) - #if original_dataroot == taget_dataroot: - #del img_path - -''' -# -------------------------------------------- -# makedir -# -------------------------------------------- -''' - - -def mkdir(path): - if not os.path.exists(path): - os.makedirs(path) - - -def mkdirs(paths): - if isinstance(paths, str): - mkdir(paths) - else: - for path in paths: - mkdir(path) - - -def mkdir_and_rename(path): - if os.path.exists(path): - new_name = path + '_archived_' + get_timestamp() - print('Path already exists. 
Rename it to [{:s}]'.format(new_name)) - os.rename(path, new_name) - os.makedirs(path) - - -''' -# -------------------------------------------- -# read image from path -# opencv is fast, but read BGR numpy image -# -------------------------------------------- -''' - - -# -------------------------------------------- -# get uint8 image of size HxWxn_channles (RGB) -# -------------------------------------------- -def imread_uint(path, n_channels=3): - # input: path - # output: HxWx3(RGB or GGG), or HxWx1 (G) - if n_channels == 1: - img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE - img = np.expand_dims(img, axis=2) # HxWx1 - elif n_channels == 3: - img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G - if img.ndim == 2: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG - else: - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB - return img - - -# -------------------------------------------- -# matlab's imwrite -# -------------------------------------------- -def imsave(img, img_path): - img = np.squeeze(img) - if img.ndim == 3: - img = img[:, :, [2, 1, 0]] - cv2.imwrite(img_path, img) - -def imwrite(img, img_path): - img = np.squeeze(img) - if img.ndim == 3: - img = img[:, :, [2, 1, 0]] - cv2.imwrite(img_path, img) - - - -# -------------------------------------------- -# get single image of size HxWxn_channles (BGR) -# -------------------------------------------- -def read_img(path): - # read image by cv2 - # return: Numpy float32, HWC, BGR, [0,1] - img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE - img = img.astype(np.float32) / 255. - if img.ndim == 2: - img = np.expand_dims(img, axis=2) - # some images have 4 channels - if img.shape[2] > 3: - img = img[:, :, :3] - return img - - -''' -# -------------------------------------------- -# image format conversion -# -------------------------------------------- -# numpy(single) <---> numpy(unit) -# numpy(single) <---> tensor -# numpy(unit) <---> tensor -# -------------------------------------------- -''' - - -# -------------------------------------------- -# numpy(single) [0, 1] <---> numpy(unit) -# -------------------------------------------- - - -def uint2single(img): - - return np.float32(img/255.) - - -def single2uint(img): - - return np.uint8((img.clip(0, 1)*255.).round()) - - -def uint162single(img): - - return np.float32(img/65535.) - - -def single2uint16(img): - - return np.uint16((img.clip(0, 1)*65535.).round()) - - -# -------------------------------------------- -# numpy(unit) (HxWxC or HxW) <---> tensor -# -------------------------------------------- - - -# convert uint to 4-dimensional torch tensor -def uint2tensor4(img): - if img.ndim == 2: - img = np.expand_dims(img, axis=2) - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) - - -# convert uint to 3-dimensional torch tensor -def uint2tensor3(img): - if img.ndim == 2: - img = np.expand_dims(img, axis=2) - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) 
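For orientation, the uint <-> tensor helpers in this (now removed) module pair up as near-exact inverses for 8-bit inputs. A minimal standalone sketch of the 4-D pair, re-declaring the two helpers so it runs without the deleted module (assumes only NumPy and PyTorch):

import numpy as np
import torch

def uint2tensor4(img):
    # HxWxC (or HxW) uint8 -> 1xCxHxW float tensor in [0, 1]
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)

def tensor2uint(img):
    # float tensor in [0, 1] -> HxWxC (or HxW) uint8, with rounding
    img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
    if img.ndim == 3:
        img = np.transpose(img, (1, 2, 0))
    return np.uint8((img * 255.0).round())

rgb = (np.random.rand(8, 8, 3) * 255).astype(np.uint8)
assert np.array_equal(tensor2uint(uint2tensor4(rgb)), rgb)  # exact round trip for 8-bit inputs

The round trip is exact because k/255 stored in float32 rounds back to k when rescaled, which is why these helpers are safe to chain in preprocessing pipelines.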
- - -# convert 2/3/4-dimensional torch tensor to uint -def tensor2uint(img): - img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() - if img.ndim == 3: - img = np.transpose(img, (1, 2, 0)) - return np.uint8((img*255.0).round()) - - -# -------------------------------------------- -# numpy(single) (HxWxC) <---> tensor -# -------------------------------------------- - - -# convert single (HxWxC) to 3-dimensional torch tensor -def single2tensor3(img): - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() - - -# convert single (HxWxC) to 4-dimensional torch tensor -def single2tensor4(img): - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) - - -# convert torch tensor to single -def tensor2single(img): - img = img.data.squeeze().float().cpu().numpy() - if img.ndim == 3: - img = np.transpose(img, (1, 2, 0)) - - return img - -# convert torch tensor to single -def tensor2single3(img): - img = img.data.squeeze().float().cpu().numpy() - if img.ndim == 3: - img = np.transpose(img, (1, 2, 0)) - elif img.ndim == 2: - img = np.expand_dims(img, axis=2) - return img - - -def single2tensor5(img): - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) - - -def single32tensor5(img): - return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) - - -def single42tensor4(img): - return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() - - -# from skimage.io import imread, imsave -def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): - ''' - Converts a torch Tensor into an image Numpy array of BGR channel order - Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order - Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) - ''' - tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp - tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] - n_dim = tensor.dim() - if n_dim == 4: - n_img = len(tensor) - img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() - img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR - elif n_dim == 3: - img_np = tensor.numpy() - img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR - elif n_dim == 2: - img_np = tensor.numpy() - else: - raise TypeError( - 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) - if out_type == np.uint8: - img_np = (img_np * 255.0).round() - # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. - return img_np.astype(out_type) - - -''' -# -------------------------------------------- -# Augmentation, flipe and/or rotate -# -------------------------------------------- -# The following two are enough. 
-# (1) augmet_img: numpy image of WxHxC or WxH -# (2) augment_img_tensor4: tensor image 1xCxWxH -# -------------------------------------------- -''' - - -def augment_img(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' - if mode == 0: - return img - elif mode == 1: - return np.flipud(np.rot90(img)) - elif mode == 2: - return np.flipud(img) - elif mode == 3: - return np.rot90(img, k=3) - elif mode == 4: - return np.flipud(np.rot90(img, k=2)) - elif mode == 5: - return np.rot90(img) - elif mode == 6: - return np.rot90(img, k=2) - elif mode == 7: - return np.flipud(np.rot90(img, k=3)) - - -def augment_img_tensor4(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' - if mode == 0: - return img - elif mode == 1: - return img.rot90(1, [2, 3]).flip([2]) - elif mode == 2: - return img.flip([2]) - elif mode == 3: - return img.rot90(3, [2, 3]) - elif mode == 4: - return img.rot90(2, [2, 3]).flip([2]) - elif mode == 5: - return img.rot90(1, [2, 3]) - elif mode == 6: - return img.rot90(2, [2, 3]) - elif mode == 7: - return img.rot90(3, [2, 3]).flip([2]) - - -def augment_img_tensor(img, mode=0): - '''Kai Zhang (github: https://github.com/cszn) - ''' - img_size = img.size() - img_np = img.data.cpu().numpy() - if len(img_size) == 3: - img_np = np.transpose(img_np, (1, 2, 0)) - elif len(img_size) == 4: - img_np = np.transpose(img_np, (2, 3, 1, 0)) - img_np = augment_img(img_np, mode=mode) - img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) - if len(img_size) == 3: - img_tensor = img_tensor.permute(2, 0, 1) - elif len(img_size) == 4: - img_tensor = img_tensor.permute(3, 2, 0, 1) - - return img_tensor.type_as(img) - - -def augment_img_np3(img, mode=0): - if mode == 0: - return img - elif mode == 1: - return img.transpose(1, 0, 2) - elif mode == 2: - return img[::-1, :, :] - elif mode == 3: - img = img[::-1, :, :] - img = img.transpose(1, 0, 2) - return img - elif mode == 4: - return img[:, ::-1, :] - elif mode == 5: - img = img[:, ::-1, :] - img = img.transpose(1, 0, 2) - return img - elif mode == 6: - img = img[:, ::-1, :] - img = img[::-1, :, :] - return img - elif mode == 7: - img = img[:, ::-1, :] - img = img[::-1, :, :] - img = img.transpose(1, 0, 2) - return img - - -def augment_imgs(img_list, hflip=True, rot=True): - # horizontal flip OR rotate - hflip = hflip and random.random() < 0.5 - vflip = rot and random.random() < 0.5 - rot90 = rot and random.random() < 0.5 - - def _augment(img): - if hflip: - img = img[:, ::-1, :] - if vflip: - img = img[::-1, :, :] - if rot90: - img = img.transpose(1, 0, 2) - return img - - return [_augment(img) for img in img_list] - - -''' -# -------------------------------------------- -# modcrop and shave -# -------------------------------------------- -''' - - -def modcrop(img_in, scale): - # img_in: Numpy, HWC or HW - img = np.copy(img_in) - if img.ndim == 2: - H, W = img.shape - H_r, W_r = H % scale, W % scale - img = img[:H - H_r, :W - W_r] - elif img.ndim == 3: - H, W, C = img.shape - H_r, W_r = H % scale, W % scale - img = img[:H - H_r, :W - W_r, :] - else: - raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) - return img - - -def shave(img_in, border=0): - # img_in: Numpy, HWC or HW - img = np.copy(img_in) - h, w = img.shape[:2] - img = img[border:h-border, border:w-border] - return img - - -''' -# -------------------------------------------- -# image processing process on numpy image -# channel_convert(in_c, tar_type, img_list): -# rgb2ycbcr(img, only_y=True): -# bgr2ycbcr(img, only_y=True): -# 
ycbcr2rgb(img):
-# --------------------------------------------
-'''
-
-
-def rgb2ycbcr(img, only_y=True):
-    '''same as MATLAB rgb2ycbcr
-    only_y: only return Y channel
-    Input:
-        uint8, [0, 255]
-        float, [0, 1]
-    '''
-    in_img_type = img.dtype
-    img = img.astype(np.float32)  # work on a float copy
-    if in_img_type != np.uint8:
-        img *= 255.
-    # convert
-    if only_y:
-        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
-    else:
-        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
-                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
-    if in_img_type == np.uint8:
-        rlt = rlt.round()
-    else:
-        rlt /= 255.
-    return rlt.astype(in_img_type)
-
-
-def ycbcr2rgb(img):
-    '''same as MATLAB ycbcr2rgb
-    Input:
-        uint8, [0, 255]
-        float, [0, 1]
-    '''
-    in_img_type = img.dtype
-    img = img.astype(np.float32)  # work on a float copy
-    if in_img_type != np.uint8:
-        img *= 255.
-    # convert
-    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
-                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
-    if in_img_type == np.uint8:
-        rlt = rlt.round()
-    else:
-        rlt /= 255.
-    return rlt.astype(in_img_type)
-
-
-def bgr2ycbcr(img, only_y=True):
-    '''BGR version of rgb2ycbcr
-    only_y: only return Y channel
-    Input:
-        uint8, [0, 255]
-        float, [0, 1]
-    '''
-    in_img_type = img.dtype
-    img = img.astype(np.float32)  # work on a float copy
-    if in_img_type != np.uint8:
-        img *= 255.
-    # convert
-    if only_y:
-        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
-    else:
-        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
-                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
-    if in_img_type == np.uint8:
-        rlt = rlt.round()
-    else:
-        rlt /= 255.
-    return rlt.astype(in_img_type)
-
-
-def channel_convert(in_c, tar_type, img_list):
-    # conversion among BGR, gray and y
-    if in_c == 3 and tar_type == 'gray':  # BGR to gray
-        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
-        return [np.expand_dims(img, axis=2) for img in gray_list]
-    elif in_c == 3 and tar_type == 'y':  # BGR to y
-        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
-        return [np.expand_dims(img, axis=2) for img in y_list]
-    elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
-        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
-    else:
-        return img_list
-
-
-'''
-# --------------------------------------------
-# metric, PSNR and SSIM
-# --------------------------------------------
-'''
-
-
-# --------------------------------------------
-# PSNR
-# --------------------------------------------
-def calculate_psnr(img1, img2, border=0):
-    # img1 and img2 have range [0, 255]
-    #img1 = img1.squeeze()
-    #img2 = img2.squeeze()
-    if not img1.shape == img2.shape:
-        raise ValueError('Input images must have the same dimensions.')
-    h, w = img1.shape[:2]
-    img1 = img1[border:h-border, border:w-border]
-    img2 = img2[border:h-border, border:w-border]
-
-    img1 = img1.astype(np.float64)
-    img2 = img2.astype(np.float64)
-    mse = np.mean((img1 - img2)**2)
-    if mse == 0:
-        return float('inf')
-    return 20 * math.log10(255.0 / math.sqrt(mse))
-
-
-# --------------------------------------------
-# SSIM
-# --------------------------------------------
-def calculate_ssim(img1, img2, border=0):
-    '''calculate SSIM
-    the same outputs as MATLAB's
-    img1, img2: [0, 255]
-    '''
-    #img1 = img1.squeeze()
-    #img2 = img2.squeeze()
-    if not img1.shape == img2.shape:
-        raise ValueError('Input images must have the same dimensions.')
-    h, w = img1.shape[:2]
-    img1 = img1[border:h-border, border:w-border]
-    img2 =
img2[border:h-border, border:w-border] - - if img1.ndim == 2: - return ssim(img1, img2) - elif img1.ndim == 3: - if img1.shape[2] == 3: - ssims = [] - for i in range(3): - ssims.append(ssim(img1[:,:,i], img2[:,:,i])) - return np.array(ssims).mean() - elif img1.shape[2] == 1: - return ssim(np.squeeze(img1), np.squeeze(img2)) - else: - raise ValueError('Wrong input image dimensions.') - - -def ssim(img1, img2): - C1 = (0.01 * 255)**2 - C2 = (0.03 * 255)**2 - - img1 = img1.astype(np.float64) - img2 = img2.astype(np.float64) - kernel = cv2.getGaussianKernel(11, 1.5) - window = np.outer(kernel, kernel.transpose()) - - mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid - mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] - mu1_sq = mu1**2 - mu2_sq = mu2**2 - mu1_mu2 = mu1 * mu2 - sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq - sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq - sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 - - ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * - (sigma1_sq + sigma2_sq + C2)) - return ssim_map.mean() - - -''' -# -------------------------------------------- -# matlab's bicubic imresize (numpy and torch) [0, 1] -# -------------------------------------------- -''' - - -# matlab 'imresize' function, now only support 'bicubic' -def cubic(x): - absx = torch.abs(x) - absx2 = absx**2 - absx3 = absx**3 - return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ - (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) - - -def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): - if (scale < 1) and (antialiasing): - # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width - kernel_width = kernel_width / scale - - # Output-space coordinates - x = torch.linspace(1, out_length, out_length) - - # Input-space coordinates. Calculate the inverse mapping such that 0.5 - # in output space maps to 0.5 in input space, and 0.5+scale in output - # space maps to 1.5 in input space. - u = x / scale + 0.5 * (1 - 1 / scale) - - # What is the left-most pixel that can be involved in the computation? - left = torch.floor(u - kernel_width / 2) - - # What is the maximum number of pixels that can be involved in the - # computation? Note: it's OK to use an extra pixel here; if the - # corresponding weights are all zero, it will be eliminated at the end - # of this function. - P = math.ceil(kernel_width) + 2 - - # The indices of the input pixels involved in computing the k-th output - # pixel are in row k of the indices matrix. - indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( - 1, P).expand(out_length, P) - - # The weights used to compute the k-th output pixel are in row k of the - # weights matrix. - distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices - # apply cubic kernel - if (scale < 1) and (antialiasing): - weights = scale * cubic(distance_to_center * scale) - else: - weights = cubic(distance_to_center) - # Normalize the weights matrix so that each row sums to 1. - weights_sum = torch.sum(weights, 1).view(out_length, 1) - weights = weights / weights_sum.expand(out_length, P) - - # If a column in weights is all zero, get rid of it. only consider the first and last column. 
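-    # (Only the first and last columns can be dead taps: P is padded by at most
-    #  two pixels beyond the true kernel support, and those extras land at the
-    #  edges of each window.)
-    # Illustrative sketch (not part of the original source): for a 2x downscale
-    # of an 8-pixel signal the antialiased support is 4 / 0.5 = 8 taps, so
-    # P = ceil(8) + 2 = 10 before trimming:
-    #
-    #   weights, indices, sym_len_s, sym_len_e = calculate_weights_indices(
-    #       8, 4, 0.5, 'cubic', 4, antialiasing=True)
-    #   weights.shape   # (4, P), each row summing to ~1 after normalization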
- weights_zero_tmp = torch.sum((weights == 0), 0) - if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): - indices = indices.narrow(1, 1, P - 2) - weights = weights.narrow(1, 1, P - 2) - if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): - indices = indices.narrow(1, 0, P - 2) - weights = weights.narrow(1, 0, P - 2) - weights = weights.contiguous() - indices = indices.contiguous() - sym_len_s = -indices.min() + 1 - sym_len_e = indices.max() - in_length - indices = indices + sym_len_s - 1 - return weights, indices, int(sym_len_s), int(sym_len_e) - - -# -------------------------------------------- -# imresize for tensor image [0, 1] -# -------------------------------------------- -def imresize(img, scale, antialiasing=True): - # Now the scale should be the same for H and W - # input: img: pytorch tensor, CHW or HW [0,1] - # output: CHW or HW [0,1] w/o round - need_squeeze = True if img.dim() == 2 else False - if need_squeeze: - img.unsqueeze_(0) - in_C, in_H, in_W = img.size() - out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) - kernel_width = 4 - kernel = 'cubic' - - # Return the desired dimension order for performing the resize. The - # strategy is to perform the resize first along the dimension with the - # smallest scale factor. - # Now we do not support this. - - # get weights and indices - weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( - in_H, out_H, scale, kernel, kernel_width, antialiasing) - weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( - in_W, out_W, scale, kernel, kernel_width, antialiasing) - # process H dimension - # symmetric copying - img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) - img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) - - sym_patch = img[:, :sym_len_Hs, :] - inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(1, inv_idx) - img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) - - sym_patch = img[:, -sym_len_He:, :] - inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(1, inv_idx) - img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) - - out_1 = torch.FloatTensor(in_C, out_H, in_W) - kernel_width = weights_H.size(1) - for i in range(out_H): - idx = int(indices_H[i][0]) - for j in range(out_C): - out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) - - # process W dimension - # symmetric copying - out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) - out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) - - sym_patch = out_1[:, :, :sym_len_Ws] - inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(2, inv_idx) - out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) - - sym_patch = out_1[:, :, -sym_len_We:] - inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(2, inv_idx) - out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) - - out_2 = torch.FloatTensor(in_C, out_H, out_W) - kernel_width = weights_W.size(1) - for i in range(out_W): - idx = int(indices_W[i][0]) - for j in range(out_C): - out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) - if need_squeeze: - out_2.squeeze_() - return out_2 - - -# -------------------------------------------- -# imresize for numpy image [0, 1] -# -------------------------------------------- -def 
imresize_np(img, scale, antialiasing=True): - # Now the scale should be the same for H and W - # input: img: Numpy, HWC or HW [0,1] - # output: HWC or HW [0,1] w/o round - img = torch.from_numpy(img) - need_squeeze = True if img.dim() == 2 else False - if need_squeeze: - img.unsqueeze_(2) - - in_H, in_W, in_C = img.size() - out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) - kernel_width = 4 - kernel = 'cubic' - - # Return the desired dimension order for performing the resize. The - # strategy is to perform the resize first along the dimension with the - # smallest scale factor. - # Now we do not support this. - - # get weights and indices - weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( - in_H, out_H, scale, kernel, kernel_width, antialiasing) - weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( - in_W, out_W, scale, kernel, kernel_width, antialiasing) - # process H dimension - # symmetric copying - img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) - img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) - - sym_patch = img[:sym_len_Hs, :, :] - inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(0, inv_idx) - img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) - - sym_patch = img[-sym_len_He:, :, :] - inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(0, inv_idx) - img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) - - out_1 = torch.FloatTensor(out_H, in_W, in_C) - kernel_width = weights_H.size(1) - for i in range(out_H): - idx = int(indices_H[i][0]) - for j in range(out_C): - out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) - - # process W dimension - # symmetric copying - out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) - out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) - - sym_patch = out_1[:, :sym_len_Ws, :] - inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(1, inv_idx) - out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) - - sym_patch = out_1[:, -sym_len_We:, :] - inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() - sym_patch_inv = sym_patch.index_select(1, inv_idx) - out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) - - out_2 = torch.FloatTensor(out_H, out_W, in_C) - kernel_width = weights_W.size(1) - for i in range(out_W): - idx = int(indices_W[i][0]) - for j in range(out_C): - out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) - if need_squeeze: - out_2.squeeze_() - - return out_2.numpy() - - -if __name__ == '__main__': - print('---') -# img = imread_uint('test.bmp', 3) -# img = uint2single(img) -# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file diff --git a/ldm/modules/losses/__init__.py b/ldm/modules/losses/__init__.py deleted file mode 100644 index 876d7c5b..00000000 --- a/ldm/modules/losses/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator \ No newline at end of file diff --git a/ldm/modules/losses/contperceptual.py b/ldm/modules/losses/contperceptual.py deleted file mode 100644 index 672c1e32..00000000 --- a/ldm/modules/losses/contperceptual.py +++ /dev/null @@ -1,111 +0,0 @@ -import torch -import torch.nn as nn - -from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
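The wildcard import above is what actually supplies LPIPS, NLayerDiscriminator, weights_init, hinge_d_loss, vanilla_d_loss and adopt_weight to this module. A more explicit equivalent, sketched under the assumption of the contemporary taming-transformers package layout, would be:

from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
from taming.modules.losses.lpips import LPIPS
from taming.modules.losses.vqperceptual import adopt_weight, hinge_d_loss, vanilla_d_loss

Explicit imports would also make the taming dependency flagged in the TODO above easy to audit.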
- - -class LPIPSWithDiscriminator(nn.Module): - def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, - disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, - perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, - disc_loss="hinge"): - - super().__init__() - assert disc_loss in ["hinge", "vanilla"] - self.kl_weight = kl_weight - self.pixel_weight = pixelloss_weight - self.perceptual_loss = LPIPS().eval() - self.perceptual_weight = perceptual_weight - # output log variance - self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) - - self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, - n_layers=disc_num_layers, - use_actnorm=use_actnorm - ).apply(weights_init) - self.discriminator_iter_start = disc_start - self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss - self.disc_factor = disc_factor - self.discriminator_weight = disc_weight - self.disc_conditional = disc_conditional - - def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): - if last_layer is not None: - nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] - else: - nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] - - d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) - d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() - d_weight = d_weight * self.discriminator_weight - return d_weight - - def forward(self, inputs, reconstructions, posteriors, optimizer_idx, - global_step, last_layer=None, cond=None, split="train", - weights=None): - rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) - if self.perceptual_weight > 0: - p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) - rec_loss = rec_loss + self.perceptual_weight * p_loss - - nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar - weighted_nll_loss = nll_loss - if weights is not None: - weighted_nll_loss = weights*nll_loss - weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] - nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] - kl_loss = posteriors.kl() - kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] - - # now the GAN part - if optimizer_idx == 0: - # generator update - if cond is None: - assert not self.disc_conditional - logits_fake = self.discriminator(reconstructions.contiguous()) - else: - assert self.disc_conditional - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) - g_loss = -torch.mean(logits_fake) - - if self.disc_factor > 0.0: - try: - d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) - except RuntimeError: - assert not self.training - d_weight = torch.tensor(0.0) - else: - d_weight = torch.tensor(0.0) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss - - log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), - "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): 
torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } - return loss, log - - if optimizer_idx == 1: - # second pass for discriminator update - if cond is None: - logits_real = self.discriminator(inputs.contiguous().detach()) - logits_fake = self.discriminator(reconstructions.contiguous().detach()) - else: - logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - - log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean() - } - return d_loss, log - diff --git a/ldm/modules/losses/vqperceptual.py b/ldm/modules/losses/vqperceptual.py deleted file mode 100644 index f6998176..00000000 --- a/ldm/modules/losses/vqperceptual.py +++ /dev/null @@ -1,167 +0,0 @@ -import torch -from torch import nn -import torch.nn.functional as F -from einops import repeat - -from taming.modules.discriminator.model import NLayerDiscriminator, weights_init -from taming.modules.losses.lpips import LPIPS -from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss - - -def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): - assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] - loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) - loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) - loss_real = (weights * loss_real).sum() / weights.sum() - loss_fake = (weights * loss_fake).sum() / weights.sum() - d_loss = 0.5 * (loss_real + loss_fake) - return d_loss - -def adopt_weight(weight, global_step, threshold=0, value=0.): - if global_step < threshold: - weight = value - return weight - - -def measure_perplexity(predicted_indices, n_embed): - # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py - # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally - encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) - avg_probs = encodings.mean(0) - perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() - cluster_use = torch.sum(avg_probs > 0) - return perplexity, cluster_use - -def l1(x, y): - return torch.abs(x-y) - - -def l2(x, y): - return torch.pow((x-y), 2) - - -class VQLPIPSWithDiscriminator(nn.Module): - def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, - disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, - perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, - disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", - pixel_loss="l1"): - super().__init__() - assert disc_loss in ["hinge", "vanilla"] - assert perceptual_loss in ["lpips", "clips", "dists"] - assert pixel_loss in ["l1", "l2"] - self.codebook_weight = codebook_weight - self.pixel_weight = pixelloss_weight - if perceptual_loss == "lpips": - print(f"{self.__class__.__name__}: Running with LPIPS.") - self.perceptual_loss = LPIPS().eval() - else: - raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") - self.perceptual_weight = perceptual_weight - - if pixel_loss == "l1": - self.pixel_loss = l1 - else: - self.pixel_loss = l2 - - self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, - n_layers=disc_num_layers, - use_actnorm=use_actnorm, - ndf=disc_ndf - ).apply(weights_init) - self.discriminator_iter_start = disc_start - if disc_loss == "hinge": - self.disc_loss = hinge_d_loss - elif disc_loss == "vanilla": - self.disc_loss = vanilla_d_loss - else: - raise ValueError(f"Unknown GAN loss '{disc_loss}'.") - print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") - self.disc_factor = disc_factor - self.discriminator_weight = disc_weight - self.disc_conditional = disc_conditional - self.n_classes = n_classes - - def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): - if last_layer is not None: - nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] - else: - nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] - - d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) - d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() - d_weight = d_weight * self.discriminator_weight - return d_weight - - def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, - global_step, last_layer=None, cond=None, split="train", predicted_indices=None): - if not exists(codebook_loss): - codebook_loss = torch.tensor([0.]).to(inputs.device) - #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) - rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) - if self.perceptual_weight > 0: - p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) - rec_loss = rec_loss + self.perceptual_weight * p_loss - else: - p_loss = torch.tensor([0.0]) - - nll_loss = rec_loss - #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] - nll_loss = torch.mean(nll_loss) - - # now the GAN part - if optimizer_idx == 0: - # generator update - if cond is None: - assert not self.disc_conditional - logits_fake = self.discriminator(reconstructions.contiguous()) - else: - assert self.disc_conditional - 
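-                # conditional-discriminator path: `cond` is concatenated onto the
-                # reconstruction along the channel axis, so it must be an
-                # image-like tensor whose batch and spatial sizes match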
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) - g_loss = -torch.mean(logits_fake) - - try: - d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) - except RuntimeError: - assert not self.training - d_weight = torch.tensor(0.0) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() - - log = {"{}/total_loss".format(split): loss.clone().detach().mean(), - "{}/quant_loss".format(split): codebook_loss.detach().mean(), - "{}/nll_loss".format(split): nll_loss.detach().mean(), - "{}/rec_loss".format(split): rec_loss.detach().mean(), - "{}/p_loss".format(split): p_loss.detach().mean(), - "{}/d_weight".format(split): d_weight.detach(), - "{}/disc_factor".format(split): torch.tensor(disc_factor), - "{}/g_loss".format(split): g_loss.detach().mean(), - } - if predicted_indices is not None: - assert self.n_classes is not None - with torch.no_grad(): - perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) - log[f"{split}/perplexity"] = perplexity - log[f"{split}/cluster_usage"] = cluster_usage - return loss, log - - if optimizer_idx == 1: - # second pass for discriminator update - if cond is None: - logits_real = self.discriminator(inputs.contiguous().detach()) - logits_fake = self.discriminator(reconstructions.contiguous().detach()) - else: - logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) - logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) - - disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) - - log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), - "{}/logits_real".format(split): logits_real.detach().mean(), - "{}/logits_fake".format(split): logits_fake.detach().mean() - } - return d_loss, log diff --git a/ldm/modules/x_transformer.py b/ldm/modules/x_transformer.py deleted file mode 100644 index 5fc15bf9..00000000 --- a/ldm/modules/x_transformer.py +++ /dev/null @@ -1,641 +0,0 @@ -"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" -import torch -from torch import nn, einsum -import torch.nn.functional as F -from functools import partial -from inspect import isfunction -from collections import namedtuple -from einops import rearrange, repeat, reduce - -# constants - -DEFAULT_DIM_HEAD = 64 - -Intermediates = namedtuple('Intermediates', [ - 'pre_softmax_attn', - 'post_softmax_attn' -]) - -LayerIntermediates = namedtuple('Intermediates', [ - 'hiddens', - 'attn_intermediates' -]) - - -class AbsolutePositionalEmbedding(nn.Module): - def __init__(self, dim, max_seq_len): - super().__init__() - self.emb = nn.Embedding(max_seq_len, dim) - self.init_() - - def init_(self): - nn.init.normal_(self.emb.weight, std=0.02) - - def forward(self, x): - n = torch.arange(x.shape[1], device=x.device) - return self.emb(n)[None, :, :] - - -class FixedPositionalEmbedding(nn.Module): - def __init__(self, dim): - super().__init__() - inv_freq = 1. 
/ (10000 ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer('inv_freq', inv_freq) - - def forward(self, x, seq_dim=1, offset=0): - t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset - sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) - emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) - return emb[None, :, :] - - -# helpers - -def exists(val): - return val is not None - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def always(val): - def inner(*args, **kwargs): - return val - return inner - - -def not_equals(val): - def inner(x): - return x != val - return inner - - -def equals(val): - def inner(x): - return x == val - return inner - - -def max_neg_value(tensor): - return -torch.finfo(tensor.dtype).max - - -# keyword argument helpers - -def pick_and_pop(keys, d): - values = list(map(lambda key: d.pop(key), keys)) - return dict(zip(keys, values)) - - -def group_dict_by_key(cond, d): - return_val = [dict(), dict()] - for key in d.keys(): - match = bool(cond(key)) - ind = int(not match) - return_val[ind][key] = d[key] - return (*return_val,) - - -def string_begins_with(prefix, str): - return str.startswith(prefix) - - -def group_by_key_prefix(prefix, d): - return group_dict_by_key(partial(string_begins_with, prefix), d) - - -def groupby_prefix_and_trim(prefix, d): - kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) - kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) - return kwargs_without_prefix, kwargs - - -# classes -class Scale(nn.Module): - def __init__(self, value, fn): - super().__init__() - self.value = value - self.fn = fn - - def forward(self, x, **kwargs): - x, *rest = self.fn(x, **kwargs) - return (x * self.value, *rest) - - -class Rezero(nn.Module): - def __init__(self, fn): - super().__init__() - self.fn = fn - self.g = nn.Parameter(torch.zeros(1)) - - def forward(self, x, **kwargs): - x, *rest = self.fn(x, **kwargs) - return (x * self.g, *rest) - - -class ScaleNorm(nn.Module): - def __init__(self, dim, eps=1e-5): - super().__init__() - self.scale = dim ** -0.5 - self.eps = eps - self.g = nn.Parameter(torch.ones(1)) - - def forward(self, x): - norm = torch.norm(x, dim=-1, keepdim=True) * self.scale - return x / norm.clamp(min=self.eps) * self.g - - -class RMSNorm(nn.Module): - def __init__(self, dim, eps=1e-8): - super().__init__() - self.scale = dim ** -0.5 - self.eps = eps - self.g = nn.Parameter(torch.ones(dim)) - - def forward(self, x): - norm = torch.norm(x, dim=-1, keepdim=True) * self.scale - return x / norm.clamp(min=self.eps) * self.g - - -class Residual(nn.Module): - def forward(self, x, residual): - return x + residual - - -class GRUGating(nn.Module): - def __init__(self, dim): - super().__init__() - self.gru = nn.GRUCell(dim, dim) - - def forward(self, x, residual): - gated_output = self.gru( - rearrange(x, 'b n d -> (b n) d'), - rearrange(residual, 'b n d -> (b n) d') - ) - - return gated_output.reshape_as(x) - - -# feedforward - -class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): - super().__init__() - inner_dim = int(dim * mult) - dim_out = default(dim_out, dim) - 
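-        # when glu=True, GEGLU ("GLU Variants Improve Transformer", Shazeer 2020)
-        # projects to 2 * inner_dim and gates one half with GELU of the other;
-        # otherwise a single Linear + GELU of width inner_dim is used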
project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) - - self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) - ) - - def forward(self, x): - return self.net(x) - - -# attention. -class Attention(nn.Module): - def __init__( - self, - dim, - dim_head=DEFAULT_DIM_HEAD, - heads=8, - causal=False, - mask=None, - talking_heads=False, - sparse_topk=None, - use_entmax15=False, - num_mem_kv=0, - dropout=0., - on_attn=False - ): - super().__init__() - if use_entmax15: - raise NotImplementedError("Check out entmax activation instead of softmax activation!") - self.scale = dim_head ** -0.5 - self.heads = heads - self.causal = causal - self.mask = mask - - inner_dim = dim_head * heads - - self.to_q = nn.Linear(dim, inner_dim, bias=False) - self.to_k = nn.Linear(dim, inner_dim, bias=False) - self.to_v = nn.Linear(dim, inner_dim, bias=False) - self.dropout = nn.Dropout(dropout) - - # talking heads - self.talking_heads = talking_heads - if talking_heads: - self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) - self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) - - # explicit topk sparse attention - self.sparse_topk = sparse_topk - - # entmax - #self.attn_fn = entmax15 if use_entmax15 else F.softmax - self.attn_fn = F.softmax - - # add memory key / values - self.num_mem_kv = num_mem_kv - if num_mem_kv > 0: - self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) - self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) - - # attention on attention - self.attn_on_attn = on_attn - self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) - - def forward( - self, - x, - context=None, - mask=None, - context_mask=None, - rel_pos=None, - sinusoidal_emb=None, - prev_attn=None, - mem=None - ): - b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device - kv_input = default(context, x) - - q_input = x - k_input = kv_input - v_input = kv_input - - if exists(mem): - k_input = torch.cat((mem, k_input), dim=-2) - v_input = torch.cat((mem, v_input), dim=-2) - - if exists(sinusoidal_emb): - # in shortformer, the query would start at a position offset depending on the past cached memory - offset = k_input.shape[-2] - q_input.shape[-2] - q_input = q_input + sinusoidal_emb(q_input, offset=offset) - k_input = k_input + sinusoidal_emb(k_input) - - q = self.to_q(q_input) - k = self.to_k(k_input) - v = self.to_v(v_input) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) - - input_mask = None - if any(map(exists, (mask, context_mask))): - q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) - k_mask = q_mask if not exists(context) else context_mask - k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) - q_mask = rearrange(q_mask, 'b i -> b () i ()') - k_mask = rearrange(k_mask, 'b j -> b () () j') - input_mask = q_mask * k_mask - - if self.num_mem_kv > 0: - mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) - k = torch.cat((mem_k, k), dim=-2) - v = torch.cat((mem_v, v), dim=-2) - if exists(input_mask): - input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) - - dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale - mask_value = max_neg_value(dots) - - if exists(prev_attn): - dots = dots + prev_attn - - pre_softmax_attn = dots - - if talking_heads: - dots = 
einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() - - if exists(rel_pos): - dots = rel_pos(dots) - - if exists(input_mask): - dots.masked_fill_(~input_mask, mask_value) - del input_mask - - if self.causal: - i, j = dots.shape[-2:] - r = torch.arange(i, device=device) - mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') - mask = F.pad(mask, (j - i, 0), value=False) - dots.masked_fill_(mask, mask_value) - del mask - - if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: - top, _ = dots.topk(self.sparse_topk, dim=-1) - vk = top[..., -1].unsqueeze(-1).expand_as(dots) - mask = dots < vk - dots.masked_fill_(mask, mask_value) - del mask - - attn = self.attn_fn(dots, dim=-1) - post_softmax_attn = attn - - attn = self.dropout(attn) - - if talking_heads: - attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() - - out = einsum('b h i j, b h j d -> b h i d', attn, v) - out = rearrange(out, 'b h n d -> b n (h d)') - - intermediates = Intermediates( - pre_softmax_attn=pre_softmax_attn, - post_softmax_attn=post_softmax_attn - ) - - return self.to_out(out), intermediates - - -class AttentionLayers(nn.Module): - def __init__( - self, - dim, - depth, - heads=8, - causal=False, - cross_attend=False, - only_cross=False, - use_scalenorm=False, - use_rmsnorm=False, - use_rezero=False, - rel_pos_num_buckets=32, - rel_pos_max_distance=128, - position_infused_attn=False, - custom_layers=None, - sandwich_coef=None, - par_ratio=None, - residual_attn=False, - cross_residual_attn=False, - macaron=False, - pre_norm=True, - gate_residual=False, - **kwargs - ): - super().__init__() - ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) - attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) - - dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) - - self.dim = dim - self.depth = depth - self.layers = nn.ModuleList([]) - - self.has_pos_emb = position_infused_attn - self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None - self.rotary_pos_emb = always(None) - - assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' - self.rel_pos = None - - self.pre_norm = pre_norm - - self.residual_attn = residual_attn - self.cross_residual_attn = cross_residual_attn - - norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm - norm_class = RMSNorm if use_rmsnorm else norm_class - norm_fn = partial(norm_class, dim) - - norm_fn = nn.Identity if use_rezero else norm_fn - branch_fn = Rezero if use_rezero else None - - if cross_attend and not only_cross: - default_block = ('a', 'c', 'f') - elif cross_attend and only_cross: - default_block = ('c', 'f') - else: - default_block = ('a', 'f') - - if macaron: - default_block = ('f',) + default_block - - if exists(custom_layers): - layer_types = custom_layers - elif exists(par_ratio): - par_depth = depth * len(default_block) - assert 1 < par_ratio <= par_depth, 'par ratio out of range' - default_block = tuple(filter(not_equals('f'), default_block)) - par_attn = par_depth // par_ratio - depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper - par_width = (depth_cut + depth_cut // par_attn) // par_attn - assert len(default_block) <= par_width, 'default block is too large for par_ratio' - par_block = default_block + ('f',) * (par_width - len(default_block)) - par_head = par_block * par_attn - layer_types = par_head + ('f',) * (par_depth - 
len(par_head)) - elif exists(sandwich_coef): - assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' - layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef - else: - layer_types = default_block * depth - - self.layer_types = layer_types - self.num_attn_layers = len(list(filter(equals('a'), layer_types))) - - for layer_type in self.layer_types: - if layer_type == 'a': - layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) - elif layer_type == 'c': - layer = Attention(dim, heads=heads, **attn_kwargs) - elif layer_type == 'f': - layer = FeedForward(dim, **ff_kwargs) - layer = layer if not macaron else Scale(0.5, layer) - else: - raise Exception(f'invalid layer type {layer_type}') - - if isinstance(layer, Attention) and exists(branch_fn): - layer = branch_fn(layer) - - if gate_residual: - residual_fn = GRUGating(dim) - else: - residual_fn = Residual() - - self.layers.append(nn.ModuleList([ - norm_fn(), - layer, - residual_fn - ])) - - def forward( - self, - x, - context=None, - mask=None, - context_mask=None, - mems=None, - return_hiddens=False - ): - hiddens = [] - intermediates = [] - prev_attn = None - prev_cross_attn = None - - mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers - - for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): - is_last = ind == (len(self.layers) - 1) - - if layer_type == 'a': - hiddens.append(x) - layer_mem = mems.pop(0) - - residual = x - - if self.pre_norm: - x = norm(x) - - if layer_type == 'a': - out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, - prev_attn=prev_attn, mem=layer_mem) - elif layer_type == 'c': - out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) - elif layer_type == 'f': - out = block(x) - - x = residual_fn(out, residual) - - if layer_type in ('a', 'c'): - intermediates.append(inter) - - if layer_type == 'a' and self.residual_attn: - prev_attn = inter.pre_softmax_attn - elif layer_type == 'c' and self.cross_residual_attn: - prev_cross_attn = inter.pre_softmax_attn - - if not self.pre_norm and not is_last: - x = norm(x) - - if return_hiddens: - intermediates = LayerIntermediates( - hiddens=hiddens, - attn_intermediates=intermediates - ) - - return x, intermediates - - return x - - -class Encoder(AttentionLayers): - def __init__(self, **kwargs): - assert 'causal' not in kwargs, 'cannot set causality on encoder' - super().__init__(causal=False, **kwargs) - - - -class TransformerWrapper(nn.Module): - def __init__( - self, - *, - num_tokens, - max_seq_len, - attn_layers, - emb_dim=None, - max_mem_len=0., - emb_dropout=0., - num_memory_tokens=None, - tie_embedding=False, - use_pos_emb=True - ): - super().__init__() - assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' - - dim = attn_layers.dim - emb_dim = default(emb_dim, dim) - - self.max_seq_len = max_seq_len - self.max_mem_len = max_mem_len - self.num_tokens = num_tokens - - self.token_emb = nn.Embedding(num_tokens, emb_dim) - self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( - use_pos_emb and not attn_layers.has_pos_emb) else always(0) - self.emb_dropout = nn.Dropout(emb_dropout) - - self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() - self.attn_layers = attn_layers - self.norm = nn.LayerNorm(dim) - - self.init_() - - self.to_logits 
= nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() - - # memory tokens (like [cls]) from Memory Transformers paper - num_memory_tokens = default(num_memory_tokens, 0) - self.num_memory_tokens = num_memory_tokens - if num_memory_tokens > 0: - self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) - - # let funnel encoder know number of memory tokens, if specified - if hasattr(attn_layers, 'num_memory_tokens'): - attn_layers.num_memory_tokens = num_memory_tokens - - def init_(self): - nn.init.normal_(self.token_emb.weight, std=0.02) - - def forward( - self, - x, - return_embeddings=False, - mask=None, - return_mems=False, - return_attn=False, - mems=None, - **kwargs - ): - b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens - x = self.token_emb(x) - x += self.pos_emb(x) - x = self.emb_dropout(x) - - x = self.project_emb(x) - - if num_mem > 0: - mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) - x = torch.cat((mem, x), dim=1) - - # auto-handle masking after appending memory tokens - if exists(mask): - mask = F.pad(mask, (num_mem, 0), value=True) - - x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) - x = self.norm(x) - - mem, x = x[:, :num_mem], x[:, num_mem:] - - out = self.to_logits(x) if not return_embeddings else x - - if return_mems: - hiddens = intermediates.hiddens - new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens - new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) - return out, new_mems - - if return_attn: - attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) - return out, attn_maps - - return out - diff --git a/ldm/util.py b/ldm/util.py deleted file mode 100644 index 8ba38853..00000000 --- a/ldm/util.py +++ /dev/null @@ -1,203 +0,0 @@ -import importlib - -import torch -import numpy as np -from collections import abc -from einops import rearrange -from functools import partial - -import multiprocessing as mp -from threading import Thread -from queue import Queue - -from inspect import isfunction -from PIL import Image, ImageDraw, ImageFont - - -def log_txt_as_img(wh, xc, size=10): - # wh a tuple of (width, height) - # xc a list of captions to plot - b = len(xc) - txts = list() - for bi in range(b): - txt = Image.new("RGB", wh, color="white") - draw = ImageDraw.Draw(txt) - font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) - nc = int(40 * (wh[0] / 256)) - lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) - - try: - draw.text((0, 0), lines, fill="black", font=font) - except UnicodeEncodeError: - print("Cant encode string for logging. Skipping.") - - txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 - txts.append(txt) - txts = np.stack(txts) - txts = torch.tensor(txts) - return txts - - -def ismap(x): - if not isinstance(x, torch.Tensor): - return False - return (len(x.shape) == 4) and (x.shape[1] > 3) - - -def isimage(x): - if not isinstance(x, torch.Tensor): - return False - return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) - - -def exists(x): - return x is not None - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def mean_flat(tensor): - """ - https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 - Take the mean over all non-batch dimensions. 
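-    Example: a tensor of shape (B, C, H, W) is averaged over dims (1, 2, 3),
-    giving a (B,)-shaped vector of per-sample means.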
- """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - -def count_params(model, verbose=False): - total_params = sum(p.numel() for p in model.parameters()) - if verbose: - print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") - return total_params - - -def instantiate_from_config(config): - if not "target" in config: - if config == '__is_first_stage__': - return None - elif config == "__is_unconditional__": - return None - raise KeyError("Expected key `target` to instantiate.") - return get_obj_from_str(config["target"])(**config.get("params", dict())) - - -def get_obj_from_str(string, reload=False): - module, cls = string.rsplit(".", 1) - if reload: - module_imp = importlib.import_module(module) - importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) - - -def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): - # create dummy dataset instance - - # run prefetching - if idx_to_fn: - res = func(data, worker_id=idx) - else: - res = func(data) - Q.put([idx, res]) - Q.put("Done") - - -def parallel_data_prefetch( - func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False -): - # if target_data_type not in ["ndarray", "list"]: - # raise ValueError( - # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." - # ) - if isinstance(data, np.ndarray) and target_data_type == "list": - raise ValueError("list expected but function got ndarray.") - elif isinstance(data, abc.Iterable): - if isinstance(data, dict): - print( - f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' - ) - data = list(data.values()) - if target_data_type == "ndarray": - data = np.asarray(data) - else: - data = list(data) - else: - raise TypeError( - f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." - ) - - if cpu_intensive: - Q = mp.Queue(1000) - proc = mp.Process - else: - Q = Queue(1000) - proc = Thread - # spawn processes - if target_data_type == "ndarray": - arguments = [ - [func, Q, part, i, use_worker_id] - for i, part in enumerate(np.array_split(data, n_proc)) - ] - else: - step = ( - int(len(data) / n_proc + 1) - if len(data) % n_proc != 0 - else int(len(data) / n_proc) - ) - arguments = [ - [func, Q, part, i, use_worker_id] - for i, part in enumerate( - [data[i: i + step] for i in range(0, len(data), step)] - ) - ] - processes = [] - for i in range(n_proc): - p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) - processes += [p] - - # start processes - print(f"Start prefetching...") - import time - - start = time.time() - gather_res = [[] for _ in range(n_proc)] - try: - for p in processes: - p.start() - - k = 0 - while k < n_proc: - # get result - res = Q.get() - if res == "Done": - k += 1 - else: - gather_res[res[0]] = res[1] - - except Exception as e: - print("Exception: ", e) - for p in processes: - p.terminate() - - raise e - finally: - for p in processes: - p.join() - print(f"Prefetching complete. 
[{time.time() - start} sec.]") - - if target_data_type == 'ndarray': - if not isinstance(gather_res[0], np.ndarray): - return np.concatenate([np.asarray(r) for r in gather_res], axis=0) - - # order outputs - return np.concatenate(gather_res, axis=0) - elif target_data_type == 'list': - out = [] - for r in gather_res: - out.extend(r) - return out - else: - return gather_res diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3ec3f98a..edb8b420 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -28,7 +28,7 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At # new memory efficient cross attention blocks do not support hypernets and we already # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention -ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention +# ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention # silence new console spam from SD2 ldm.modules.attention.print = lambda *args: None @@ -82,7 +82,12 @@ class StableDiffusionModelHijack: def hijack(self, m): - if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder: + if shared.text_model_name == "XLMR-Large": + model_embeddings = m.cond_stage_model.roberta.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) + m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) + + elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder: model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) @@ -91,11 +96,7 @@ class StableDiffusionModelHijack: m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self) m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) apply_optimizations() - elif shared.text_model_name == "XLMR-Large": - model_embeddings = m.cond_stage_model.roberta.embeddings - model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) - m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) - + self.clip = m.cond_stage_model fix_checkpoint() diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index b451d1cf..9ea6e1ce 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -4,7 +4,7 @@ import torch from modules import prompt_parser, devices from modules.shared import opts - +import modules.shared as shared def get_target_prompt_token_count(token_count): return math.ceil(max(token_count, 1) / 75) * 75 @@ -177,6 +177,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count def forward(self, text): + if shared.text_model_name == "XLMR-Large": + return self.wrapped.encode(text) + use_old = opts.use_old_emphasis_implementation if use_old: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = 
self.process_text_old(text) @@ -254,7 +257,10 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): def __init__(self, wrapped, hijack): super().__init__(wrapped, hijack) self.tokenizer = wrapped.tokenizer - self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] + if shared.text_model_name == "XLMR-Large": + self.comma_token = None + else : + self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] self.token_mults = {} tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] diff --git a/modules/xlmr.py b/modules/xlmr.py new file mode 100644 index 00000000..beab3fdf --- /dev/null +++ b/modules/xlmr.py @@ -0,0 +1,137 @@ +from transformers import BertPreTrainedModel,BertModel,BertConfig +import torch.nn as nn +import torch +from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig +from transformers import XLMRobertaModel,XLMRobertaTokenizer +from typing import Optional + +class BertSeriesConfig(BertConfig): + def __init__(self, vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None,project_dim=512, pooler_fn="average",learn_encoder=False,model_type='bert',**kwargs): + + super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size, initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache, classifier_dropout, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + +class RobertaSeriesConfig(XLMRobertaConfig): + def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2,project_dim=512,pooler_fn='cls',learn_encoder=False, **kwargs): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + self.project_dim = project_dim + self.pooler_fn = pooler_fn + self.learn_encoder = learn_encoder + + +class BertSeriesModelWithTransformation(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + config_class = BertSeriesConfig + + def __init__(self, config=None, **kargs): + # modify initialization for autoloading + if config is None: + config = XLMRobertaConfig() + config.attention_probs_dropout_prob= 0.1 + config.bos_token_id=0 + config.eos_token_id=2 + config.hidden_act='gelu' + config.hidden_dropout_prob=0.1 + config.hidden_size=1024 + config.initializer_range=0.02 + config.intermediate_size=4096 + config.layer_norm_eps=1e-05 + config.max_position_embeddings=514 + + config.num_attention_heads=16 + config.num_hidden_layers=24 + config.output_past=True + config.pad_token_id=1 + config.position_embedding_type= "absolute" + + config.type_vocab_size= 1 + config.use_cache=True + config.vocab_size= 250002 + config.project_dim = 768 + config.learn_encoder = False + super().__init__(config) + self.roberta = XLMRobertaModel(config) + self.transformation = nn.Linear(config.hidden_size,config.project_dim) + 
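+        # maps 1024-d XLM-R hidden states down to project_dim (768 above), the
+        # text-embedding width that Stable Diffusion v1 cross-attention expects
+        # (an inference from the hard-coded config, not a documented guarantee)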
self.pre_LN=nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') + self.pooler = lambda x: x[:,0] + self.post_init() + + def encode(self,c): + device = next(self.parameters()).device + text = self.tokenizer(c, + truncation=True, + max_length=77, + return_length=False, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt") + text["input_ids"] = torch.tensor(text["input_ids"]).to(device) + text["attention_mask"] = torch.tensor( + text['attention_mask']).to(device) + features = self(**text) + return features['projection_state'] + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) : + r""" + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=return_dict, + ) + + # last module outputs + sequence_output = outputs[0] + + + # project every module + sequence_output_ln = self.pre_LN(sequence_output) + + # pooler + pooler_output = self.pooler(sequence_output_ln) + pooler_output = self.transformation(pooler_output) + projection_state = self.transformation(outputs.last_hidden_state) + + return { + 'pooler_output':pooler_output, + 'last_hidden_state':outputs.last_hidden_state, + 'hidden_states':outputs.hidden_states, + 'attentions':outputs.attentions, + 'projection_state':projection_state, + 'sequence_out': sequence_output + } + + +class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): + base_model_prefix = 'roberta' + config_class= RobertaSeriesConfig \ No newline at end of file -- cgit v1.2.3 From 9c86fb8cace6d8ac0843e0ddad0ba5ae7f3148c9 Mon Sep 17 00:00:00 2001 From: zhaohu xing <920232796@qq.com> Date: Fri, 2 Dec 2022 16:08:46 +0800 Subject: fix bug Signed-off-by: zhaohu xing <920232796@qq.com> --- modules/shared.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index 1408dee3..ac7678c3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -111,7 +111,11 @@ restricted_opts = { from omegaconf import OmegaConf config = OmegaConf.load(f"{cmd_opts.config}") # XLMR-Large -text_model_name = config.model.params.cond_stage_config.params.name +try: + text_model_name = config.model.params.cond_stage_config.params.name + +except : + text_model_name = "stable_diffusion" cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access -- cgit v1.2.3 From 4929503258d80abbc4b5f40da034298fe3803906 Mon Sep 17 00:00:00 2001 From: zhaohu xing <920232796@qq.com> Date: Tue, 6 Dec 2022 09:03:55 +0800 Subject: fix bugs Signed-off-by: zhaohu 
xing <920232796@qq.com> --- modules/devices.py | 4 ++-- modules/sd_hijack.py | 2 +- v2-inference.yaml | 67 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+), 3 deletions(-) create mode 100644 v2-inference.yaml diff --git a/modules/devices.py b/modules/devices.py index e69c1fe3..f00079c6 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -38,8 +38,8 @@ def get_optimal_device(): if torch.cuda.is_available(): return torch.device(get_cuda_device_string()) - # if has_mps(): - # return torch.device("mps") + if has_mps(): + return torch.device("mps") return cpu diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index edb8b420..cd65d356 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -28,7 +28,7 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At # new memory efficient cross attention blocks do not support hypernets and we already # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention -# ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention +ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention # silence new console spam from SD2 ldm.modules.attention.print = lambda *args: None diff --git a/v2-inference.yaml b/v2-inference.yaml new file mode 100644 index 00000000..0eb25395 --- /dev/null +++ b/v2-inference.yaml @@ -0,0 +1,67 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" \ No newline at end of file -- cgit v1.2.3 From 5dcc22606d05ebe5ae89c990bd83a3eb068fcb78 Mon Sep 17 00:00:00 2001 From: zhaohu xing <920232796@qq.com> Date: Tue, 6 Dec 2022 16:04:50 +0800 Subject: add hash and fix undo hijack bug Signed-off-by: zhaohu xing <920232796@qq.com> --- .DS_Store | Bin 0 -> 6148 bytes launch.py | 10 ++++---- modules/sd_hijack.py | 6 ++++- v2-inference-v.yaml | 68 +++++++++++++++++++++++++++++++++++++++++++++++++++ v2-inference.yaml 
| 67 -------------------------------------------------- 5 files changed, 78 insertions(+), 73 deletions(-) create mode 100644 .DS_Store create mode 100644 v2-inference-v.yaml delete mode 100644 v2-inference.yaml diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 00000000..5008ddfc Binary files /dev/null and b/.DS_Store differ diff --git a/launch.py b/launch.py index 0d8f2776..0e1bbaf2 100644 --- a/launch.py +++ b/launch.py @@ -234,11 +234,11 @@ def prepare_enviroment(): os.makedirs(dir_repos, exist_ok=True) - git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", ) - git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", ) - git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", ) - git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", ) - git_clone(blip_repo, repo_dir('BLIP'), "BLIP", ) + git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) + git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash) + git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) + git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) + git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) if not is_installed("lpips"): run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer") diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 9b5890e7..9fed1b6f 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -112,7 +112,11 @@ class StableDiffusionModelHijack: self.layers = flatten(m) def undo_hijack(self, m): - if type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords: + + if shared.text_model_name == "XLMR-Large": + m.cond_stage_model = m.cond_stage_model.wrapped + + elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords: m.cond_stage_model = m.cond_stage_model.wrapped model_embeddings = m.cond_stage_model.transformer.text_model.embeddings diff --git a/v2-inference-v.yaml b/v2-inference-v.yaml new file mode 100644 index 00000000..513cd635 --- /dev/null +++ b/v2-inference-v.yaml @@ -0,0 +1,68 @@ +model: + base_learning_rate: 1.0e-4 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False # we set this to false because this is an inference only config + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + use_fp16: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 
+ in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" \ No newline at end of file diff --git a/v2-inference.yaml b/v2-inference.yaml deleted file mode 100644 index 0eb25395..00000000 --- a/v2-inference.yaml +++ /dev/null @@ -1,67 +0,0 @@ -model: - base_learning_rate: 1.0e-4 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False # we set this to false because this is an inference only config - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_checkpoint: True - use_fp16: True - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 # need to fix for flash-attn - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - #attn_type: "vanilla-xformers" - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate" \ No newline at end of file -- cgit v1.2.3 From 965fc5ac5a6ccdf38342e21c97183011a04e799e Mon Sep 17 00:00:00 2001 From: zhaohu xing <920232796@qq.com> Date: Tue, 6 Dec 2022 16:15:15 +0800 Subject: delete a file Signed-off-by: zhaohu xing <920232796@qq.com> --- .DS_Store | Bin 6148 -> 0 bytes modules/shared.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) delete mode 100644 .DS_Store diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 5008ddfc..00000000 Binary files a/.DS_Store and /dev/null differ diff --git a/modules/shared.py b/modules/shared.py index 522c56c1..8419b531 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -22,7 +22,7 @@ demo = None sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file parser = argparse.ArgumentParser() -parser.add_argument("--config", type=str, default="configs/altdiffusion/ad-inference.yaml", help="path to config which constructs model",) +parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) -- cgit v1.2.3 
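A minimal sketch (not part of any patch in this series) of how an inference YAML such as v2-inference-v.yaml above is read; it assumes only omegaconf, which this series also imports in modules/shared.py, and that the file exists on disk:

    from omegaconf import OmegaConf

    config = OmegaConf.load("v2-inference-v.yaml")
    params = config.model.params

    # v-prediction checkpoints mark themselves with `parameterization: "v"`;
    # plain epsilon-prediction configs simply omit the key
    print(params.get("parameterization", "eps"))   # -> "v" for this file

    # the text encoder is selected by a dotted import path
    print(params.cond_stage_config.target)         # -> ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder

The `target:` entries are dotted import paths: the ldm loader instantiates whichever class they name, which is why the later alt-diffusion commit can switch text encoders just by pointing cond_stage_config.target at modules.xlmr.BertSeriesModelWithTransformation.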
From 9fd457e21d6c809a69a1318f03d75f7b3e09b865 Mon Sep 17 00:00:00 2001 From: camenduru <54370274+camenduru@users.noreply.github.com> Date: Thu, 15 Dec 2022 21:57:48 +0300 Subject: allow_credentials and allow_headers for api from https://fastapi.tiangolo.com/tutorial/cors/ --- webui.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/webui.py b/webui.py index c2d0c6be..13a4d14a 100644 --- a/webui.py +++ b/webui.py @@ -90,11 +90,11 @@ def initialize(): def setup_cors(app): if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex: - app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*']) + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) elif cmd_opts.cors_allow_origins: - app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*']) + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*']) elif cmd_opts.cors_allow_origins_regex: - app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*']) + app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) def create_api(app): -- cgit v1.2.3 From f23a822f1c9cb3bd2e8772c75af429e06515eaef Mon Sep 17 00:00:00 2001 From: Philpax Date: Sat, 24 Dec 2022 20:45:16 +1100 Subject: feat(api): include job_timestamp in progress --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/shared.py b/modules/shared.py index 8ea3b441..f356dbf7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -171,6 +171,7 @@ class State: "interrupted": self.skipped, "job": self.job, "job_count": self.job_count, + "job_timestamp": self.job_timestamp, "job_no": self.job_no, "sampling_step": self.sampling_step, "sampling_steps": self.sampling_steps, -- cgit v1.2.3 From fa931733f6acc94e058a1d3d4655846e33ae34be Mon Sep 17 00:00:00 2001 From: Philpax Date: Sun, 25 Dec 2022 20:17:49 +1100 Subject: fix(api): assign sd_model after settings change --- modules/api/api.py | 2 -- modules/processing.py | 6 ++++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 1ceba75d..0a1a1905 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -121,7 +121,6 @@ class Api: def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): populate = txt2imgreq.copy(update={ # Override __init__ params - "sd_model": shared.sd_model, "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True @@ -153,7 +152,6 @@ class Api: mask = decode_base64_to_image(mask) populate = img2imgreq.copy(update={ # Override __init__ params - "sd_model": shared.sd_model, "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True, diff --git a/modules/processing.py b/modules/processing.py index 4a406084..0b270278 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -50,9 +50,9 @@ def apply_color_correction(correction, original_image): correction, channel_axis=2 ), 
cv2.COLOR_LAB2RGB).astype("uint8")) - + image = blendLayers(image, original_image, BlendType.LUMINOSITY) - + return image @@ -466,6 +466,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if k == 'sd_model_checkpoint': sd_models.reload_model_weights() # make onchange call for changing SD model if k == 'sd_vae': sd_vae.reload_vae_weights() # make onchange call for changing VAE + # Assign sd_model here to ensure that it reflects the model after any changes + p.sd_model = shared.sd_model res = process_images_inner(p) finally: -- cgit v1.2.3 From 5be9387b230794a8c771120577cb213490c940c0 Mon Sep 17 00:00:00 2001 From: Philpax Date: Sun, 25 Dec 2022 21:45:44 +1100 Subject: fix(api): only begin/end state in lock --- modules/api/api.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 1ceba75d..59b81c93 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -130,14 +130,12 @@ class Api: if populate.sampler_name: populate.sampler_index = None # prevent a warning later on p = StableDiffusionProcessingTxt2Img(**vars(populate)) - # Override object param - - shared.state.begin() with self.queue_lock: + shared.state.begin() processed = process_images(p) + shared.state.end() - shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) @@ -169,12 +167,10 @@ class Api: p.init_images = [decode_base64_to_image(x) for x in init_images] - shared.state.begin() - with self.queue_lock: + shared.state.begin() processed = process_images(p) - - shared.state.end() + shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) -- cgit v1.2.3 From 893933e05ad267778111b4fad6d1ecb80937afdf Mon Sep 17 00:00:00 2001 From: hitomi Date: Sun, 25 Dec 2022 20:49:25 +0800 Subject: Add memory cache for VAE weights --- modules/sd_vae.py | 31 +++++++++++++++++++++++++------ modules/shared.py | 1 + 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 3856418e..ac71d62d 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,5 +1,6 @@ import torch import os +import collections from collections import namedtuple from modules import shared, devices, script_callbacks from modules.paths import models_path @@ -30,6 +31,7 @@ base_vae = None loaded_vae_file = None checkpoint_info = None +checkpoints_loaded = collections.OrderedDict() def get_base_vae(model): if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: @@ -149,13 +151,30 @@ def load_vae(model, vae_file=None): global first_load, vae_dict, vae_list, loaded_vae_file # save_settings = False + cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 + if vae_file: - assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" - print(f"Loading VAE weights from: {vae_file}") - store_base_vae(model) - vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} - _load_vae_dict(model, vae_dict_1) + if cache_enabled and vae_file in checkpoints_loaded: + # use vae checkpoint cache + print(f"Loading VAE weights [{get_filename(vae_file)}] from cache") + store_base_vae(model) + _load_vae_dict(model, checkpoints_loaded[vae_file]) + else: + assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" + print(f"Loading VAE weights from: {vae_file}") + store_base_vae(model) + vae_ckpt = torch.load(vae_file, 
map_location=shared.weight_load_location) + vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} + _load_vae_dict(model, vae_dict_1) + + if cache_enabled: + # cache newly loaded vae + checkpoints_loaded[vae_file] = vae_dict_1.copy() + + # clean up cache if limit is reached + if cache_enabled: + while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model + checkpoints_loaded.popitem(last=False) # LRU # If vae used is not in dict, update it # It will be removed on refresh though diff --git a/modules/shared.py b/modules/shared.py index d4ddeea0..671d30e1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -356,6 +356,7 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), + "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list), "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), -- cgit v1.2.3 From 4af3ca5393151d61363c30eef4965e694eeac15e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 26 Dec 2022 10:11:28 +0300 Subject: make it so that blank ENSD does not break image generation --- modules/processing.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 4a406084..0a9a8f95 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -338,13 +338,14 @@ def slerp(val, low, high): def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None): + eta_noise_seed_delta = opts.eta_noise_seed_delta or 0 xs = [] # if we have multiple seeds, this means we are working with batch size>1; this then # enables the generation of additional tensors with noise that the sampler will use during its processing. # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to # produce the same images as with two batches [100], [101]. 
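# Illustrative note, not a patch line: `opts.eta_noise_seed_delta or 0` coerces a
# blank ENSD setting (a gr.Number field left empty yields None) to 0, so the
# comparisons below never evaluate `None > 0`, which raises TypeError on Python 3.
# Hypothetical values:
#   None or 0     # -> 0   (blank field)
#   0 or 0        # -> 0
#   31337 or 0    # -> 31337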
- if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0): + if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0): sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))] else: sampler_noises = None @@ -384,8 +385,8 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see if sampler_noises is not None: cnt = p.sampler.number_of_needed_noises(p) - if opts.eta_noise_seed_delta > 0: - torch.manual_seed(seed + opts.eta_noise_seed_delta) + if eta_noise_seed_delta > 0: + torch.manual_seed(seed + eta_noise_seed_delta) for j in range(cnt): sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape))) -- cgit v1.2.3 From ae955b0146a52ea2474c79655ede0d361829ef63 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 26 Dec 2022 09:53:26 -0500 Subject: fix rgba to rgb when using jpeg output --- modules/images.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/images.py b/modules/images.py index 31d4528d..962a955d 100644 --- a/modules/images.py +++ b/modules/images.py @@ -525,6 +525,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality, pnginfo=pnginfo_data) elif extension.lower() in (".jpg", ".jpeg", ".webp"): + if image_to_save.mode == 'RGBA': + image_to_save = image_to_save.convert("RGB") + image_to_save.save(temp_file_path, format=image_format, quality=opts.jpeg_quality) if opts.enable_pnginfo and info is not None: -- cgit v1.2.3 From 4df5009acb6832daef1ff5949404b5aadc8f8fa4 Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Mon, 26 Dec 2022 20:49:13 +0000 Subject: Update sd_samplers.py --- modules/sd_samplers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 177b5338..f4473832 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -462,6 +462,9 @@ class KDiffusionSampler: return extra_params_kwargs def get_sigmas(self, p, steps): + disc = opts.always_discard_next_to_last_sigma or (self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)) + steps += 1 if disc else 0 + if p.sampler_noise_scheduler_override: sigmas = p.sampler_noise_scheduler_override(steps) elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': @@ -469,7 +472,7 @@ class KDiffusionSampler: else: sigmas = self.model_wrap.get_sigmas(steps) - if self.config is not None and self.config.options.get('discard_next_to_last_sigma', False): + if disc: sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) return sigmas -- cgit v1.2.3 From 03f486a2399df0a2b24c7aeea72e64f106a87297 Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Mon, 26 Dec 2022 20:49:33 +0000 Subject: Update shared.py --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/shared.py b/modules/shared.py index d4ddeea0..5edb316c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -418,6 +418,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), + 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"), })) options_templates.update(options_section((None, "Hidden options"), { -- cgit v1.2.3 From 5ba04f9ec050a66e918571f07e8863f157f05b44 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 21 Dec 2022 13:45:58 +0100 Subject: Attempting to solve slow loads for `safetensors`. Fixes #5893 --- modules/sd_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index ecdd91c5..cd938656 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -168,7 +168,10 @@ def get_state_dict_from_checkpoint(pl_sd): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): _, extension = os.path.splitext(checkpoint_file) if extension.lower() == ".safetensors": - pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location) + device = map_location or shared.weight_load_location + if device is None: + device = "cuda:0" if torch.cuda.is_available() else "cpu" + pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) else: pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) -- cgit v1.2.3 From 5a523d13050a5ede43c473767f29dfe2e391136a Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 27 Dec 2022 11:27:40 +0100 Subject: Version 0.2.7 Fixes Windows SAFETENSORS_FAST_GPU path. --- requirements_versions.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_versions.txt b/requirements_versions.txt index c126c8c4..52e98818 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -26,5 +26,5 @@ lark==1.1.2 inflection==0.5.1 GitPython==3.1.27 torchsde==0.2.5 -safetensors==0.2.5 +safetensors==0.2.7 httpcore<=0.15 -- cgit v1.2.3 From 5958bbd244703f7c248a91e86dea5d52acc85505 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Fri, 30 Dec 2022 19:36:36 -0500 Subject: add additional memory states --- modules/memmon.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/memmon.py b/modules/memmon.py index 9fb9b687..a7060f58 100644 --- a/modules/memmon.py +++ b/modules/memmon.py @@ -71,10 +71,13 @@ class MemUsageMonitor(threading.Thread): def read(self): if not self.disabled: free, total = torch.cuda.mem_get_info() + self.data["free"] = free self.data["total"] = total torch_stats = torch.cuda.memory_stats(self.device) + self.data["active"] = torch_stats["active.all.current"] self.data["active_peak"] = torch_stats["active_bytes.all.peak"] + self.data["reserved"] = torch_stats["reserved_bytes.all.current"] self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"] self.data["system_peak"] = total - self.data["min_free"] -- cgit v1.2.3 From d3aa2a48e1e896b6ffafda5367200a4bbd46b0d7 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Fri, 30 Dec 2022 19:38:53 -0500 Subject: remove unnecessary console message --- modules/sd_hijack_inpainting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index bb5499b3..06b75772 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -178,7 +178,7 @@ def sample_plms(self, # sampling C, H, W = shape size = (batch_size, C, H, W) - print(f'Data shape for PLMS sampling is {size}') + # print(f'Data shape for PLMS sampling is {size}') # remove 
unnecessary message samples, intermediates = self.plms_sampling(conditioning, size, callback=callback, -- cgit v1.2.3 From 463048344fc036b262aa132584b65ee6e9fec6cf Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Fri, 30 Dec 2022 19:41:47 -0500 Subject: fix shared state dictionary --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index d4ddeea0..9a13fb60 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -168,7 +168,7 @@ class State: def dict(self): obj = { "skipped": self.skipped, - "interrupted": self.skipped, + "interrupted": self.interrupted, "job": self.job, "job_count": self.job_count, "job_no": self.job_no, -- cgit v1.2.3 From fef98723b2b1c7a9893ead41bbefcb36192babd6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 31 Dec 2022 12:44:26 +0300 Subject: set sd_model for API later, inside the lock, to prevent multiple requests with different models ending up with incorrect results #5877 #6012 --- modules/api/api.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 59b81c93..11daff0d 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -121,7 +121,6 @@ class Api: def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): populate = txt2imgreq.copy(update={ # Override __init__ params - "sd_model": shared.sd_model, "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True @@ -129,9 +128,10 @@ class Api: ) if populate.sampler_name: populate.sampler_index = None # prevent a warning later on - p = StableDiffusionProcessingTxt2Img(**vars(populate)) with self.queue_lock: + p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate)) + shared.state.begin() processed = process_images(p) shared.state.end() @@ -151,7 +151,6 @@ class Api: mask = decode_base64_to_image(mask) populate = img2imgreq.copy(update={ # Override __init__ params - "sd_model": shared.sd_model, "sampler_name": validate_sampler_name(img2imgreq.sampler_name or img2imgreq.sampler_index), "do_not_save_samples": True, "do_not_save_grid": True, @@ -163,11 +162,11 @@ class Api: args = vars(populate) args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. 
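# Sketch of the resulting pattern (illustrative; it mirrors the + lines below):
# construct the processing object only after acquiring the queue lock, so a
# concurrent settings request cannot swap shared.sd_model between construction
# and processing:
#
#   with self.queue_lock:
#       p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)
#       p.init_images = [decode_base64_to_image(x) for x in init_images]
#       shared.state.begin()
#       processed = process_images(p)
#       shared.state.end()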
- p = StableDiffusionProcessingImg2Img(**args) - - p.init_images = [decode_base64_to_image(x) for x in init_images] with self.queue_lock: + p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) + p.init_images = [decode_base64_to_image(x) for x in init_images] + shared.state.begin() processed = process_images(p) shared.state.end() -- cgit v1.2.3 From 65be1df7bb55b21a3d76630a397c820218cbd12a Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sat, 31 Dec 2022 07:46:04 -0500 Subject: initialize result so not to cause exception on empty results --- modules/interrogate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/interrogate.py b/modules/interrogate.py index 46935210..6f761c5a 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -135,7 +135,7 @@ class InterrogateModels: return caption[0] def interrogate(self, pil_image): - res = None + res = "" try: -- cgit v1.2.3 From f34c7341720fb2059992926c9f9ae6ff25f7385b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 31 Dec 2022 18:06:35 +0300 Subject: alt-diffusion integration --- configs/alt-diffusion-inference.yaml | 72 ++++++++++++++++++++++++++++++++++ configs/altdiffusion/ad-inference.yaml | 72 ---------------------------------- configs/v1-inference.yaml | 70 +++++++++++++++++++++++++++++++++ modules/sd_hijack.py | 18 +++++---- modules/sd_hijack_clip.py | 14 +++---- modules/sd_hijack_xlmr.py | 34 ++++++++++++++++ modules/shared.py | 10 +---- v1-inference.yaml | 70 --------------------------------- 8 files changed, 192 insertions(+), 168 deletions(-) create mode 100644 configs/alt-diffusion-inference.yaml delete mode 100644 configs/altdiffusion/ad-inference.yaml create mode 100644 configs/v1-inference.yaml create mode 100644 modules/sd_hijack_xlmr.py delete mode 100644 v1-inference.yaml diff --git a/configs/alt-diffusion-inference.yaml b/configs/alt-diffusion-inference.yaml new file mode 100644 index 00000000..cfbee72d --- /dev/null +++ b/configs/alt-diffusion-inference.yaml @@ -0,0 +1,72 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. 
] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: modules.xlmr.BertSeriesModelWithTransformation + params: + name: "XLMR-Large" \ No newline at end of file diff --git a/configs/altdiffusion/ad-inference.yaml b/configs/altdiffusion/ad-inference.yaml deleted file mode 100644 index cfbee72d..00000000 --- a/configs/altdiffusion/ad-inference.yaml +++ /dev/null @@ -1,72 +0,0 @@ -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. 
] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: modules.xlmr.BertSeriesModelWithTransformation - params: - name: "XLMR-Large" \ No newline at end of file diff --git a/configs/v1-inference.yaml b/configs/v1-inference.yaml new file mode 100644 index 00000000..d4effe56 --- /dev/null +++ b/configs/v1-inference.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 10000 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. 
] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index bce23b03..edcbaf52 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -5,7 +5,7 @@ import modules.textual_inversion.textual_inversion from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts -from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet +from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr from modules.sd_hijack_optimizations import invokeAI_mps_available @@ -68,6 +68,7 @@ def fix_checkpoint(): ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward + class StableDiffusionModelHijack: fixes = None comments = [] @@ -79,21 +80,22 @@ class StableDiffusionModelHijack: def hijack(self, m): - if shared.text_model_name == "XLMR-Large": + if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) - m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) - + m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self) + elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder: model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) - apply_optimizations() + elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder: m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self) m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) - apply_optimizations() - + + apply_optimizations() + self.clip = m.cond_stage_model fix_checkpoint() @@ -109,7 +111,7 @@ class StableDiffusionModelHijack: def undo_hijack(self, m): - if shared.text_model_name == "XLMR-Large": + if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: m.cond_stage_model = m.cond_stage_model.wrapped elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords: diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 9ea6e1ce..6ec50cca 100644 --- a/modules/sd_hijack_clip.py +++ 
b/modules/sd_hijack_clip.py @@ -4,7 +4,6 @@ import torch from modules import prompt_parser, devices from modules.shared import opts -import modules.shared as shared def get_target_prompt_token_count(token_count): return math.ceil(max(token_count, 1) / 75) * 75 @@ -177,9 +176,6 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count def forward(self, text): - if shared.text_model_name == "XLMR-Large": - return self.wrapped.encode(text) - use_old = opts.use_old_emphasis_implementation if use_old: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) @@ -257,13 +253,13 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): def __init__(self, wrapped, hijack): super().__init__(wrapped, hijack) self.tokenizer = wrapped.tokenizer - if shared.text_model_name == "XLMR-Large": - self.comma_token = None - else : - self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] + + vocab = self.tokenizer.get_vocab() + + self.comma_token = vocab.get(',', None) self.token_mults = {} - tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] + tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k] for text, ident in tokens_with_parens: mult = 1.0 for c in text: diff --git a/modules/sd_hijack_xlmr.py b/modules/sd_hijack_xlmr.py new file mode 100644 index 00000000..4ac51c38 --- /dev/null +++ b/modules/sd_hijack_xlmr.py @@ -0,0 +1,34 @@ +import open_clip.tokenizer +import torch + +from modules import sd_hijack_clip, devices +from modules.shared import opts + + +class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + self.id_start = wrapped.config.bos_token_id + self.id_end = wrapped.config.eos_token_id + self.id_pad = wrapped.config.pad_token_id + + self.comma_token = self.tokenizer.get_vocab().get(',', None) # alt diffusion doesn't have bits for comma + + def encode_with_transformers(self, tokens): + # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a + # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer + # layer to work with - you have to use the last + + attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64) + features = self.wrapped(input_ids=tokens, attention_mask=attention_mask) + z = features['projection_state'] + + return z + + def encode_embedding_init_text(self, init_text, nvpt): + embedding_layer = self.wrapped.roberta.embeddings + ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"] + embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) + + return embedded diff --git a/modules/shared.py b/modules/shared.py index 2b31e717..715b9169 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -23,7 +23,7 @@ demo = None sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file parser = argparse.ArgumentParser() -parser.add_argument("--config", type=str, default=os.path.join(script_path, "v1-inference.yaml"), help="path to config which constructs model",) 
+parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) @@ -108,14 +108,6 @@ restricted_opts = { "outdir_txt2img_grids", "outdir_save", } -from omegaconf import OmegaConf -config = OmegaConf.load(f"{cmd_opts.config}") -# XLMR-Large -try: - text_model_name = config.model.params.cond_stage_config.params.name - -except : - text_model_name = "stable_diffusion" cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access diff --git a/v1-inference.yaml b/v1-inference.yaml deleted file mode 100644 index d4effe56..00000000 --- a/v1-inference.yaml +++ /dev/null @@ -1,70 +0,0 @@ -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. 
] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder -- cgit v1.2.3 From f55ac33d446185680604e872ceda2ae858821d5c Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sat, 31 Dec 2022 11:27:02 -0500 Subject: validate textual inversion embeddings --- modules/sd_models.py | 3 ++ modules/textual_inversion/textual_inversion.py | 43 +++++++++++++++++++++++--- modules/ui.py | 2 -- 3 files changed, 41 insertions(+), 7 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index ecdd91c5..ebd4dff7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -325,6 +325,9 @@ def load_model(checkpoint_info=None): script_callbacks.model_loaded_callback(sd_model) print("Model loaded.") + + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload = True) # Reload embeddings after model load as they may or may not fit the model + return sd_model diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index f6112578..103ace60 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -23,6 +23,8 @@ class Embedding: self.vec = vec self.name = name self.step = step + self.shape = None + self.vectors = 0 self.cached_checksum = None self.sd_checkpoint = None self.sd_checkpoint_name = None @@ -57,8 +59,10 @@ class EmbeddingDatabase: def __init__(self, embeddings_dir): self.ids_lookup = {} self.word_embeddings = {} + self.skipped_embeddings = [] self.dir_mtime = None self.embeddings_dir = embeddings_dir + self.expected_shape = -1 def register_embedding(self, embedding, model): @@ -75,14 +79,35 @@ class EmbeddingDatabase: return embedding - def load_textual_inversion_embeddings(self): + def get_expected_shape(self): + expected_shape = -1 # initialize with unknown + idx = torch.tensor(0).to(shared.device) + if expected_shape == -1: + try: # matches sd15 signature + first_embedding = shared.sd_model.cond_stage_model.wrapped.transformer.text_model.embeddings.token_embedding.wrapped(idx) + expected_shape = first_embedding.shape[0] + except: + pass + if expected_shape == -1: + try: # matches sd20 signature + first_embedding = shared.sd_model.cond_stage_model.wrapped.model.token_embedding.wrapped(idx) + expected_shape = first_embedding.shape[0] + except: + pass + if expected_shape == -1: + print('Could not determine expected embeddings shape from model') + return expected_shape + + def load_textual_inversion_embeddings(self, force_reload = False): mt = os.path.getmtime(self.embeddings_dir) - if self.dir_mtime is not None and mt <= self.dir_mtime: + if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime: return self.dir_mtime = mt self.ids_lookup.clear() self.word_embeddings.clear() + 
self.skipped_embeddings = [] + self.expected_shape = self.get_expected_shape() def process_file(path, filename): name = os.path.splitext(filename)[0] @@ -122,7 +147,14 @@ class EmbeddingDatabase: embedding.step = data.get('step', None) embedding.sd_checkpoint = data.get('sd_checkpoint', None) embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) - self.register_embedding(embedding, shared.sd_model) + embedding.vectors = vec.shape[0] + embedding.shape = vec.shape[-1] + + if (self.expected_shape == -1) or (self.expected_shape == embedding.shape): + self.register_embedding(embedding, shared.sd_model) + else: + self.skipped_embeddings.append(name) + # print('Skipping embedding {name}: shape was {shape} expected {expected}'.format(name = name, shape = embedding.shape, expected = self.expected_shape)) for fn in os.listdir(self.embeddings_dir): try: @@ -137,8 +169,9 @@ class EmbeddingDatabase: print(traceback.format_exc(), file=sys.stderr) continue - print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") - print("Embeddings:", ', '.join(self.word_embeddings.keys())) + print("Textual inversion embeddings {num} loaded: {val}".format(num = len(self.word_embeddings), val = ', '.join(self.word_embeddings.keys()))) + if (len(self.skipped_embeddings) > 0): + print("Textual inversion embeddings {num} skipped: {val}".format(num = len(self.skipped_embeddings), val = ', '.join(self.skipped_embeddings))) def find_embedding_at_position(self, tokens, offset): token = tokens[offset] diff --git a/modules/ui.py b/modules/ui.py index 57ee0465..397dd804 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1157,8 +1157,6 @@ def create_ui(): with gr.Column(variant='panel'): submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) - sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() - with gr.Blocks(analytics_enabled=False) as train_interface: with gr.Row().style(equal_height=False): gr.HTML(value="
See wiki for detailed explanation.
") -- cgit v1.2.3 From bdbe09827b39be63c9c0b3636132ca58da38ebf6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 31 Dec 2022 22:49:09 +0300 Subject: changed embedding accepted shape detection to use existing code and support the new alt-diffusion model, and reformatted messages a bit #6149 --- modules/textual_inversion/textual_inversion.py | 30 ++++++-------------------- 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 103ace60..66f40367 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -80,23 +80,8 @@ class EmbeddingDatabase: return embedding def get_expected_shape(self): - expected_shape = -1 # initialize with unknown - idx = torch.tensor(0).to(shared.device) - if expected_shape == -1: - try: # matches sd15 signature - first_embedding = shared.sd_model.cond_stage_model.wrapped.transformer.text_model.embeddings.token_embedding.wrapped(idx) - expected_shape = first_embedding.shape[0] - except: - pass - if expected_shape == -1: - try: # matches sd20 signature - first_embedding = shared.sd_model.cond_stage_model.wrapped.model.token_embedding.wrapped(idx) - expected_shape = first_embedding.shape[0] - except: - pass - if expected_shape == -1: - print('Could not determine expected embeddings shape from model') - return expected_shape + vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1) + return vec.shape[1] def load_textual_inversion_embeddings(self, force_reload = False): mt = os.path.getmtime(self.embeddings_dir) @@ -112,8 +97,6 @@ class EmbeddingDatabase: def process_file(path, filename): name = os.path.splitext(filename)[0] - data = [] - if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']: embed_image = Image.open(path) if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: @@ -150,11 +133,10 @@ class EmbeddingDatabase: embedding.vectors = vec.shape[0] embedding.shape = vec.shape[-1] - if (self.expected_shape == -1) or (self.expected_shape == embedding.shape): + if self.expected_shape == -1 or self.expected_shape == embedding.shape: self.register_embedding(embedding, shared.sd_model) else: self.skipped_embeddings.append(name) - # print('Skipping embedding {name}: shape was {shape} expected {expected}'.format(name = name, shape = embedding.shape, expected = self.expected_shape)) for fn in os.listdir(self.embeddings_dir): try: @@ -169,9 +151,9 @@ class EmbeddingDatabase: print(traceback.format_exc(), file=sys.stderr) continue - print("Textual inversion embeddings {num} loaded: {val}".format(num = len(self.word_embeddings), val = ', '.join(self.word_embeddings.keys()))) - if (len(self.skipped_embeddings) > 0): - print("Textual inversion embeddings {num} skipped: {val}".format(num = len(self.skipped_embeddings), val = ', '.join(self.skipped_embeddings))) + print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") + if len(self.skipped_embeddings) > 0: + print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings)}") def find_embedding_at_position(self, tokens, offset): token = tokens[offset] -- cgit v1.2.3 From f4535f6e4f001314bd155bc6e1b6908e02792b9a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 31 Dec 2022 23:40:55 +0300 Subject: make it so that memory/embeddings info is displayed in 
a separate UI element from generation parameters, and is preserved when you change the displayed infotext by clicking on gallery images --- modules/img2img.py | 2 +- modules/processing.py | 5 +++-- modules/txt2img.py | 2 +- modules/ui.py | 31 +++++++++++++++++-------------- 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 81da4b13..ca58b5d8 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -162,4 +162,4 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro if opts.do_not_show_images: processed.images = [] - return processed.images, generation_info_js, plaintext_to_html(processed.info) + return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments) diff --git a/modules/processing.py b/modules/processing.py index 0a9a8f95..42dc19ea 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -239,7 +239,7 @@ class StableDiffusionProcessing(): class Processed: - def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None): + def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""): self.images = images_list self.prompt = p.prompt self.negative_prompt = p.negative_prompt @@ -247,6 +247,7 @@ class Processed: self.subseed = subseed self.subseed_strength = p.subseed_strength self.info = info + self.comments = comments self.width = p.width self.height = p.height self.sampler_name = p.sampler_name @@ -646,7 +647,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: devices.torch_gc() - res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts) + res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts) if p.scripts is not None: p.scripts.postprocess(p, res) diff --git a/modules/txt2img.py b/modules/txt2img.py index c8f81176..7f61e19a 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -59,4 +59,4 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: if opts.do_not_show_images: processed.images = [] - return processed.images, generation_info_js, plaintext_to_html(processed.info) + return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments) diff --git a/modules/ui.py b/modules/ui.py index 397dd804..f550ad00 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -159,7 +159,7 @@ def save_files(js_data, images, do_make_zip, index): zip_file.writestr(filenames[i], f.read()) fullfns.insert(0, zip_filepath) - return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}") + return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") @@ -593,6 +593,8 @@ Requested path was: {f} with gr.Group(): html_info = gr.HTML() + html_log = gr.HTML() + generation_info = gr.Textbox(visible=False) if tabname == 'txt2img' or tabname == 'img2img': generation_info_button = 
gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") @@ -615,16 +617,16 @@ Requested path was: {f} ], outputs=[ download_files, - html_info, - html_info, - html_info, + html_log, ] ) else: html_info_x = gr.HTML() html_info = gr.HTML() + html_log = gr.HTML() + parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) - return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info + return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log def create_ui(): @@ -686,14 +688,14 @@ def create_ui(): with gr.Group(): custom_inputs = modules.scripts.scripts_txt2img.setup_ui() - txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples) + txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) txt2img_args = dict( - fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), _js="submit", inputs=[ txt2img_prompt, @@ -720,7 +722,8 @@ def create_ui(): outputs=[ txt2img_gallery, generation_info, - html_info + html_info, + html_log, ], show_progress=False, ) @@ -799,7 +802,6 @@ def create_ui(): with gr.Blocks(analytics_enabled=False) as img2img_interface: img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) - with gr.Row(elem_id='img2img_progress_row'): img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) @@ -883,7 +885,7 @@ def create_ui(): with gr.Group(): custom_inputs = modules.scripts.scripts_img2img.setup_ui() - img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples) + img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) @@ -915,7 +917,7 @@ def create_ui(): ) img2img_args = dict( - fn=wrap_gradio_gpu_call(modules.img2img.img2img), + fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), _js="submit_img2img", inputs=[ dummy_component, @@ -954,7 +956,8 @@ def create_ui(): outputs=[ img2img_gallery, generation_info, - html_info + html_info, + html_log, ], show_progress=False, ) @@ -1078,10 +1081,10 @@ def create_ui(): with gr.Group(): upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False) - result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples) + result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) submit.click( - fn=wrap_gradio_gpu_call(modules.extras.run_extras), + fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']), _js="get_extras_tab_index", inputs=[ dummy_component, -- cgit v1.2.3 From 
360feed9b55fb03060c236773867b08b4265645d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 1 Jan 2023 00:38:58 +0300 Subject: HAPPY NEW YEAR make save to zip into its own button instead of a checkbox --- modules/ui.py | 30 ++++++++++++++++++++++-------- style.css | 6 ++++++ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index f550ad00..279b5110 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -570,13 +570,14 @@ Requested path was: {f} generation_info = None with gr.Column(): - with gr.Row(): + with gr.Row(elem_id=f"image_buttons_{tabname}"): + open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder') + if tabname != "extras": save = gr.Button('Save', elem_id=f'save_{tabname}') + save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) - button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder' - open_folder_button = gr.Button(folder_symbol, elem_id=button_id) open_folder_button.click( fn=lambda: open_folder(opts.outdir_samples or outdir), @@ -585,9 +586,6 @@ Requested path was: {f} ) if tabname != "extras": - with gr.Row(): - do_make_zip = gr.Checkbox(label="Make Zip when Save?", value=False) - with gr.Row(): download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) @@ -608,11 +606,11 @@ Requested path was: {f} save.click( fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, z, selected_gallery_index()]", + _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", inputs=[ generation_info, result_gallery, - do_make_zip, + html_info, html_info, ], outputs=[ @@ -620,6 +618,22 @@ Requested path was: {f} html_log, ] ) + + save_zip.click( + fn=wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ] + ) + else: html_info_x = gr.HTML() html_info = gr.HTML() diff --git a/style.css b/style.css index 3ad78006..f245f674 100644 --- a/style.css +++ b/style.css @@ -568,6 +568,12 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h font-size: 95%; } +#image_buttons_txt2img button, #image_buttons_img2img button, #image_buttons_extras button{ + min-width: auto; + padding-left: 0.5em; + padding-right: 0.5em; +} + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. 
If you change anything above, you need to make sure it is RTL compliant by just running -- cgit v1.2.3 From 29a3a7eb13478297bc7093971b48827ab8246f45 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 1 Jan 2023 01:19:10 +0300 Subject: show sampler selection in dropdown, add option selection to revert to old radio group --- modules/shared.py | 1 + modules/ui.py | 22 +++++++++++++++------- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index 715b9169..948b9542 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -406,6 +406,7 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), + "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) diff --git a/modules/ui.py b/modules/ui.py index 279b5110..c7b8ea5d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -643,6 +643,19 @@ Requested path was: {f} return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log +def create_sampler_and_steps_selection(choices, tabname): + if opts.samplers_in_dropdown: + with gr.Row(elem_id=f"sampler_selection_{tabname}"): + sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20) + else: + with gr.Group(elem_id=f"sampler_selection_{tabname}"): + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20) + sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + + return steps, sampler_index + + def create_ui(): import modules.img2img import modules.txt2img @@ -660,9 +673,6 @@ def create_ui(): dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - - - with gr.Row(elem_id='txt2img_progress_row'): with gr.Column(scale=1): pass @@ -674,8 +684,7 @@ def create_ui(): with gr.Row().style(equal_height=False): with gr.Column(variant='panel', elem_id="txt2img_settings"): - steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) - sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index") + steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") with gr.Group(): width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512) @@ -875,8 +884,7 @@ def create_ui(): with gr.Row(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], 
type="index", value="Just resize") - steps = gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20) - sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index") + steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") with gr.Group(): width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") -- cgit v1.2.3 From 210449b374d522c94a67fe54289a9eb515933a9f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 1 Jan 2023 02:41:15 +0300 Subject: fix 'RuntimeError: Expected all tensors to be on the same device' error preventing models from loading on lowvram/medvram. --- modules/sd_hijack_clip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 6ec50cca..ca92b142 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -298,6 +298,6 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): def encode_embedding_init_text(self, init_text, nvpt): embedding_layer = self.wrapped.transformer.text_model.embeddings ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"] - embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) + embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0) return embedded -- cgit v1.2.3 From a939e82a0b982517aa212197a0e5f6d11daec7d0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 1 Jan 2023 03:24:58 +0300 Subject: fix weird padding for sampler dropdown in chrome --- style.css | 5 ----- 1 file changed, 5 deletions(-) diff --git a/style.css b/style.css index f245f674..4b98b84d 100644 --- a/style.css +++ b/style.css @@ -245,11 +245,6 @@ input[type="range"]{ margin: 0.5em 0 -0.3em 0; } -#txt2img_sampling label{ - padding-left: 0.6em; - padding-right: 0.6em; -} - #mask_bug_info { text-align: center; display: block; -- cgit v1.2.3 From 16b9661d2741b241c3964fcbd56559c078b84822 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 1 Jan 2023 09:51:37 +0300 Subject: change karras scheduler sigmas to values recommended by SD from old 0.1 to 10 with an option to revert to old --- modules/sd_samplers.py | 4 +++- modules/shared.py | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 177b5338..e904d860 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -465,7 +465,9 @@ class KDiffusionSampler: if p.sampler_noise_scheduler_override: sigmas = p.sampler_noise_scheduler_override(steps) elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': - sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device) + sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) + + sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device) else: sigmas = self.model_wrap.get_sigmas(steps) diff --git a/modules/shared.py b/modules/shared.py index 948b9542..7f430b93 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -368,13 +368,17 @@ 
options_templates.update(options_section(('sd', "Stable Diffusion"), { "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}), "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), - "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) +options_templates.update(options_section(('compatibility', "Compatibility"), { + "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), + "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), +})) + options_templates.update(options_section(('interrogate', "Interrogate Options"), { "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"), -- cgit v1.2.3 From 11d432d92d63660c516540dcb48faac87669b4f0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 1 Jan 2023 10:35:38 +0300 Subject: add refresh buttons to checkpoint merger --- modules/ui.py | 6 ++++++ style.css | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index c7b8ea5d..4cc2ce4f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1167,8 +1167,14 @@ def create_ui(): with gr.Row(): primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") + create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") + secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") + create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B") + tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") + create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") + custom_name = gr.Textbox(label="Custom Name (Optional)") interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3) interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method") 
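For reference, the helper these three new calls reuse has roughly this shape — a paraphrased sketch of create_refresh_button as it stands in ui.py at this point (the ToolButton variant arrives a few commits below):

import gradio as gr

refresh_symbol = '\U0001f504'  # 🔄, as defined near the top of ui.py

def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
    def refresh():
        # re-query the backing list, then push the fresh choices into the component
        refresh_method()
        args = refreshed_args() if callable(refreshed_args) else refreshed_args
        return gr.update(**(args or {}))

    refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
    refresh_button.click(fn=refresh, inputs=[], outputs=[refresh_component])
    return refresh_button
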
diff --git a/style.css b/style.css index 4b98b84d..516ef7bf 100644 --- a/style.css +++ b/style.css @@ -496,7 +496,7 @@ input[type="range"]{ padding: 0; } -#refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{ +#refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization, #refresh_checkpoint_A, #refresh_checkpoint_B, #refresh_checkpoint_C{ max-width: 2.5em; min-width: 2.5em; height: 2.4em; -- cgit v1.2.3 From 76f256fe8f844641f4e9b41f35c7dd2cba5090d6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 1 Jan 2023 11:08:39 +0300 Subject: Bump gradio version #YOLO --- modules/ui_tempdir.py | 3 ++- requirements.txt | 2 +- requirements_versions.txt | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py index 07210d14..8d519310 100644 --- a/modules/ui_tempdir.py +++ b/modules/ui_tempdir.py @@ -15,7 +15,8 @@ Savedfile = namedtuple("Savedfile", ["name"]) def save_pil_to_file(pil_image, dir=None): already_saved_as = getattr(pil_image, 'already_saved_as', None) if already_saved_as and os.path.isfile(already_saved_as): - shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(os.path.dirname(already_saved_as))} + shared.demo.temp_file_sets[0] = shared.demo.temp_file_sets[0] | {os.path.abspath(already_saved_as)} + file_obj = Savedfile(already_saved_as) return file_obj diff --git a/requirements.txt b/requirements.txt index 5bed694e..e2c3876b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ fairscale==0.4.4 fonts font-roboto gfpgan -gradio==3.9 +gradio==3.15.0 invisible-watermark numpy omegaconf diff --git a/requirements_versions.txt b/requirements_versions.txt index c126c8c4..836523ba 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -3,7 +3,7 @@ transformers==4.19.2 accelerate==0.12.0 basicsr==1.4.2 gfpgan==1.3.8 -gradio==3.9 +gradio==3.15.0 numpy==1.23.3 Pillow==9.2.0 realesrgan==0.3.0 -- cgit v1.2.3 From b46b97fa297b3a4a654da77cf98a775a2bcab4c7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 1 Jan 2023 11:38:17 +0300 Subject: more fixes for gradio update --- modules/generation_parameters_copypaste.py | 2 +- modules/ui_tempdir.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index fbd91300..54b3372d 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -38,7 +38,7 @@ def quote(text): def image_from_url_text(filedata): if type(filedata) == dict and filedata["is_file"]: filename = filedata["name"] - is_in_right_dir = any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in shared.demo.temp_dirs) + is_in_right_dir = any([filename in fileset for fileset in shared.demo.temp_file_sets]) assert is_in_right_dir, 'trying to open image file outside of allowed directories' return Image.open(filename) diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py index 8d519310..363d449d 100644 --- a/modules/ui_tempdir.py +++ b/modules/ui_tempdir.py @@ -45,7 +45,7 @@ def on_tmpdir_changed(): os.makedirs(shared.opts.temp_dir, exist_ok=True) - shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(shared.opts.temp_dir)} + shared.demo.temp_file_sets[0] = 
shared.demo.temp_file_sets[0] | {os.path.abspath(shared.opts.temp_dir)} def cleanup_tmpdr(): -- cgit v1.2.3 From e5f1a37cb9b537d95b2df47c96b4a4f7242fd294 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 1 Jan 2023 13:08:40 +0300 Subject: make refresh buttons look more nice --- modules/ui.py | 6 +++--- modules/ui_components.py | 18 ++++++++++++++++++ style.css | 28 +++++++++++++++++++++------- 3 files changed, 42 insertions(+), 10 deletions(-) create mode 100644 modules/ui_components.py diff --git a/modules/ui.py b/modules/ui.py index 4cc2ce4f..32fa80d1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -19,7 +19,7 @@ import numpy as np from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, ui_components from modules.paths import script_path from modules.shared import opts, cmd_opts, restricted_opts @@ -532,7 +532,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele return gr.update(**(args or {})) - refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id) + refresh_button = ui_components.ToolButton(value=refresh_symbol, elem_id=elem_id) refresh_button.click( fn=refresh, inputs=[], @@ -1476,7 +1476,7 @@ def create_ui(): res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) else: - with gr.Row(variant="compact"): + with ui_components.FormRow(): res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) else: diff --git a/modules/ui_components.py b/modules/ui_components.py new file mode 100644 index 00000000..d0519d2d --- /dev/null +++ b/modules/ui_components.py @@ -0,0 +1,18 @@ +import gradio as gr + + +class ToolButton(gr.Button, gr.components.FormComponent): + """Small button with single emoji as text, fits inside gradio forms""" + + def __init__(self, **kwargs): + super().__init__(variant="tool", **kwargs) + + def get_block_name(self): + return "button" + + +class FormRow(gr.Row, gr.components.FormComponent): + """Same as gr.Row but fits inside gradio forms""" + + def get_block_name(self): + return "row" diff --git a/style.css b/style.css index 516ef7bf..f168571e 100644 --- a/style.css +++ b/style.css @@ -496,13 +496,6 @@ input[type="range"]{ padding: 0; } -#refresh_sd_model_checkpoint, #refresh_sd_vae, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization, #refresh_checkpoint_A, #refresh_checkpoint_B, #refresh_checkpoint_C{ - max-width: 2.5em; - min-width: 2.5em; - height: 2.4em; -} - - canvas[key="mask"] { z-index: 12 !important; filter: invert(); @@ -569,6 +562,27 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h padding-right: 0.5em; } +.gr-form{ + background-color: white; +} + +.dark .gr-form{ + background-color: rgb(31 41 55 / var(--tw-bg-opacity)); +} + +.gr-button-tool{ + max-width: 2.5em; + min-width: 2.5em !important; + height: 2.4em; + margin: 0.55em 0; +} + +#quicksettings .gr-button-tool{ + margin: 0; +} + + + /* The following handles localization for right-to-left (RTL) languages like Arabic. 
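A short usage sketch for the two new components in a hypothetical settings row: variant="tool" is what the .gr-button-tool rule above targets, and FormRow reports itself as a form component so gradio lays the pair out inside one form block:

import gradio as gr
from modules.ui_components import FormRow, ToolButton

with gr.Blocks():
    with FormRow():
        checkpoint = gr.Dropdown(choices=["a.ckpt", "b.safetensors"], label="Checkpoint")
        refresh = ToolButton(value='\U0001f504')  # compact 🔄 button, styled by .gr-button-tool
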
The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running -- cgit v1.2.3 From 5f12b23b8bb7fca585a3a1e844881d06f171364e Mon Sep 17 00:00:00 2001 From: AlUlkesh <99896447+AlUlkesh@users.noreply.github.com> Date: Wed, 28 Dec 2022 22:18:19 +0100 Subject: Adding image numbers on grids New grid option in settings enables adding of image numbers on grids. This makes identifying the images, especially in larger batches, much easier. Revert "Adding image numbers on grids" This reverts commit 3530c283b4b1d3a3cab40efbffe4cf2697938b6f. Implements Callback for image grid loop Necessary to make "Add image's number to its picture in the grid" extension possible. --- modules/images.py | 1 + modules/script_callbacks.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/modules/images.py b/modules/images.py index 31d4528d..5afd3891 100644 --- a/modules/images.py +++ b/modules/images.py @@ -43,6 +43,7 @@ def image_grid(imgs, batch_size=1, rows=None): grid = Image.new('RGB', size=(cols * w, rows * h), color='black') for i, img in enumerate(imgs): + script_callbacks.image_grid_loop_callback(img) grid.paste(img, box=(i % cols * w, i // cols * h)) return grid diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 8e22f875..0c854407 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -51,6 +51,11 @@ class UiTrainTabParams: self.txt2img_preview_params = txt2img_preview_params +class ImageGridLoopParams: + def __init__(self, img): + self.img = img + + ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) callback_map = dict( callbacks_app_started=[], @@ -63,6 +68,7 @@ callback_map = dict( callbacks_cfg_denoiser=[], callbacks_before_component=[], callbacks_after_component=[], + callbacks_image_grid_loop=[], ) @@ -154,6 +160,12 @@ def after_component_callback(component, **kwargs): except Exception: report_exception(c, 'after_component_callback') +def image_grid_loop_callback(component, **kwargs): + for c in callback_map['callbacks_image_grid_loop']: + try: + c.callback(component, **kwargs) + except Exception: + report_exception(c, 'image_grid_loop') def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] @@ -255,3 +267,11 @@ def on_before_component(callback): def on_after_component(callback): """register a function to be called after a component is created. See on_before_component for more.""" add_callback(callback_map['callbacks_after_component'], callback) + + +def on_image_grid_loop(callback): + """register a function to be called inside the image grid loop. + The callback is called with one argument: + - params: ImageGridLoopParams - parameters to be used inside the image grid loop. 
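A hypothetical extension built on this first version of the hook might number images as they are pasted. Note that, as images.py wires it in this commit, the callback receives the raw PIL image (despite the docstring advertising ImageGridLoopParams), so the extension has to keep its own counter — a limitation the rework a few commits below removes:

from PIL import ImageDraw
from modules import script_callbacks

state = {"index": 0}

def number_image(img):
    # the callback runs before grid.paste(), so the label ends up in the grid
    ImageDraw.Draw(img).text((4, 4), str(state["index"]), fill="white")
    state["index"] += 1

script_callbacks.on_image_grid_loop(number_image)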
+ """ + add_callback(callback_map['callbacks_image_grid_loop'], callback) -- cgit v1.2.3 From 524d532b387732d4d32f237e792c7f201a934400 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 1 Jan 2023 14:07:40 +0300 Subject: moved roll artist to built-in extensions --- .../roll-artist/scripts/roll-artist.py | 50 ++++++++++++++++++++++ modules/ui.py | 37 ++-------------- 2 files changed, 53 insertions(+), 34 deletions(-) create mode 100644 extensions-builtin/roll-artist/scripts/roll-artist.py diff --git a/extensions-builtin/roll-artist/scripts/roll-artist.py b/extensions-builtin/roll-artist/scripts/roll-artist.py new file mode 100644 index 00000000..c3bc1fd0 --- /dev/null +++ b/extensions-builtin/roll-artist/scripts/roll-artist.py @@ -0,0 +1,50 @@ +import random + +from modules import script_callbacks, shared +import gradio as gr + +art_symbol = '\U0001f3a8' # 🎨 +global_prompt = None +related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" } + + +def roll_artist(prompt): + allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories]) + artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats]) + + return prompt + ", " + artist.name if prompt != '' else artist.name + + +def add_roll_button(prompt): + roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) + + roll.click( + fn=roll_artist, + _js="update_txt2img_tokens", + inputs=[ + prompt, + ], + outputs=[ + prompt, + ] + ) + + +def after_component(component, **kwargs): + global global_prompt + + elem_id = kwargs.get('elem_id', None) + if elem_id not in related_ids: + return + + if elem_id == "txt2img_prompt": + global_prompt = component + elif elem_id == "txt2img_clear_prompt": + add_roll_button(global_prompt) + elif elem_id == "img2img_prompt": + global_prompt = component + elif elem_id == "img2img_clear_prompt": + add_roll_button(global_prompt) + + +script_callbacks.on_after_component(after_component) diff --git a/modules/ui.py b/modules/ui.py index 32fa80d1..27da2c2c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -80,7 +80,6 @@ css_hide_progressbar = """ # Important that they exactly match script.js for tooltip to work. 
random_symbol = '\U0001f3b2\ufe0f' # 🎲️ reuse_symbol = '\u267b\ufe0f' # ♻️ -art_symbol = '\U0001f3a8' # 🎨 paste_symbol = '\u2199\ufe0f' # ↙ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 @@ -234,13 +233,6 @@ def check_progress_call_initial(id_part): return check_progress_call(id_part) -def roll_artist(prompt): - allowed_cats = set([x for x in shared.artist_db.categories() if len(opts.random_artist_categories)==0 or x in opts.random_artist_categories]) - artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats]) - - return prompt + ", " + artist.name if prompt != '' else artist.name - - def visit(x, func, path=""): if hasattr(x, 'children'): for c in x.children: @@ -403,7 +395,6 @@ def create_toprow(is_img2img): ) with gr.Column(scale=1, elem_id="roll_col"): - roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) paste = gr.Button(value=paste_symbol, elem_id="paste") save_style = gr.Button(value=save_style_symbol, elem_id="style_create") prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") @@ -452,7 +443,7 @@ def create_toprow(is_img2img): prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) prompt_style2.save_to_config = True - return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button + return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button def setup_progressbar(progressbar, preview, id_part, textinfo=None): @@ -668,7 +659,7 @@ def create_ui(): modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) + txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) @@ -771,16 +762,6 @@ def create_ui(): outputs=[hr_options], ) - roll.click( - fn=roll_artist, - _js="update_txt2img_tokens", - inputs=[ - txt2img_prompt, - ], - outputs=[ - txt2img_prompt, - ] - ) txt2img_paste_fields = [ (txt2img_prompt, "Prompt"), @@ -823,7 +804,7 @@ def create_ui(): modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) + img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = 
create_toprow(is_img2img=True) with gr.Row(elem_id='img2img_progress_row'): img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) @@ -999,18 +980,6 @@ def create_ui(): outputs=[img2img_prompt], ) - - roll.click( - fn=roll_artist, - _js="update_img2img_tokens", - inputs=[ - img2img_prompt, - ], - outputs=[ - img2img_prompt, - ] - ) - prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] -- cgit v1.2.3 From e672cfb07418a1a3130d3bf21c14a0d3819f81fb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 1 Jan 2023 18:37:37 +0300 Subject: rework of callback for #6094 --- modules/images.py | 10 ++++++---- modules/script_callbacks.py | 26 +++++++++++++++----------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/modules/images.py b/modules/images.py index 719aaf3b..f84fd485 100644 --- a/modules/images.py +++ b/modules/images.py @@ -39,12 +39,14 @@ def image_grid(imgs, batch_size=1, rows=None): cols = math.ceil(len(imgs) / rows) + params = script_callbacks.ImageGridLoopParams(imgs, cols, rows) + script_callbacks.image_grid_callback(params) + w, h = imgs[0].size - grid = Image.new('RGB', size=(cols * w, rows * h), color='black') + grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black') - for i, img in enumerate(imgs): - script_callbacks.image_grid_loop_callback(img) - grid.paste(img, box=(i % cols * w, i // cols * h)) + for i, img in enumerate(params.imgs): + grid.paste(img, box=(i % params.cols * w, i // params.cols * h)) return grid diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 0c854407..de69fd9f 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -52,8 +52,10 @@ class UiTrainTabParams: class ImageGridLoopParams: - def __init__(self, img): - self.img = img + def __init__(self, imgs, cols, rows): + self.imgs = imgs + self.cols = cols + self.rows = rows ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) @@ -68,7 +70,7 @@ callback_map = dict( callbacks_cfg_denoiser=[], callbacks_before_component=[], callbacks_after_component=[], - callbacks_image_grid_loop=[], + callbacks_image_grid=[], ) @@ -160,12 +162,14 @@ def after_component_callback(component, **kwargs): except Exception: report_exception(c, 'after_component_callback') -def image_grid_loop_callback(component, **kwargs): - for c in callback_map['callbacks_image_grid_loop']: + +def image_grid_callback(params: ImageGridLoopParams): + for c in callback_map['callbacks_image_grid']: try: - c.callback(component, **kwargs) + c.callback(params) except Exception: - report_exception(c, 'image_grid_loop') + report_exception(c, 'image_grid') + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] @@ -269,9 +273,9 @@ def on_after_component(callback): add_callback(callback_map['callbacks_after_component'], callback) -def on_image_grid_loop(callback): - """register a function to be called inside the image grid loop. +def on_image_grid(callback): + """register a function to be called before making an image grid. The callback is called with one argument: - - params: ImageGridLoopParams - parameters to be used inside the image grid loop. 
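With the reworked hook the whole layout is negotiable before any pasting happens, instead of one opaque call per image. A minimal sketch that forces every batch into a single row:

from modules import script_callbacks

def single_row(params):
    params.rows = 1
    params.cols = len(params.imgs)

script_callbacks.on_image_grid(single_row)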
+ - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified. """ - add_callback(callback_map['callbacks_image_grid_loop'], callback) + add_callback(callback_map['callbacks_image_grid'], callback) -- cgit v1.2.3 From a005fccddd5a37c57f1afe5234660b59b9a41508 Mon Sep 17 00:00:00 2001 From: me <25877290+Kryptortio@users.noreply.github.com> Date: Sun, 1 Jan 2023 14:51:12 +0100 Subject: Add a lot more elem_id/HTML id, modified some that were duplicates for seed section --- modules/generation_parameters_copypaste.py | 2 +- modules/ui.py | 254 ++++++++++++++--------------- style.css | 12 +- 3 files changed, 134 insertions(+), 134 deletions(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 54b3372d..8e7f0df0 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -93,7 +93,7 @@ def integrate_settings_paste_fields(component_dict): def create_buttons(tabs_list): buttons = {} for tab in tabs_list: - buttons[tab] = gr.Button(f"Send to {tab}") + buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab") return buttons diff --git a/modules/ui.py b/modules/ui.py index 27da2c2c..7070ea15 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -272,17 +272,17 @@ def interrogate_deepbooru(image): return gr_show(True) if prompt is None else prompt -def create_seed_inputs(): +def create_seed_inputs(target_interface): with gr.Row(): with gr.Box(): - with gr.Row(elem_id='seed_row'): - seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1) + with gr.Row(elem_id=target_interface + '_seed_row'): + seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') seed.style(container=False) - random_seed = gr.Button(random_symbol, elem_id='random_seed') - reuse_seed = gr.Button(reuse_symbol, elem_id='reuse_seed') + random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') + reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') - with gr.Box(elem_id='subseed_show_box'): - seed_checkbox = gr.Checkbox(label='Extra', elem_id='subseed_show', value=False) + with gr.Box(elem_id=target_interface + '_subseed_show_box'): + seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) # Components to show/hide based on the 'Extra' checkbox seed_extras = [] @@ -290,17 +290,17 @@ def create_seed_inputs(): with gr.Row(visible=False) as seed_extra_row_1: seed_extras.append(seed_extra_row_1) with gr.Box(): - with gr.Row(elem_id='subseed_row'): - subseed = gr.Number(label='Variation seed', value=-1) + with gr.Row(elem_id=target_interface + '_subseed_row'): + subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') subseed.style(container=False) - random_subseed = gr.Button(random_symbol, elem_id='random_subseed') - reuse_subseed = gr.Button(reuse_symbol, elem_id='reuse_subseed') - subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01) + random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') + reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') + subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength') with gr.Row(visible=False) as seed_extra_row_2: 
seed_extras.append(seed_extra_row_2) - seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0) - seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0) + seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w') + seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h') random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) @@ -678,28 +678,28 @@ def create_ui(): steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") with gr.Group(): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512) - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512) + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") with gr.Row(): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) - tiling = gr.Checkbox(label='Tiling', value=False) - enable_hr = gr.Checkbox(label='Highres. fix', value=False) + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") + enable_hr = gr.Checkbox(label='Highres. fix', value=False, elem_id="txt2img_enable_hr") with gr.Row(visible=False) as hr_options: - firstphase_width = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass width", value=0) - firstphase_height = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass height", value=0) - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7) + firstphase_width = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass width", value=0, elem_id="txt2img_firstphase_width") + firstphase_height = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass height", value=0, elem_id="txt2img_firstphase_height") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") with gr.Row(equal_height=True): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') - with gr.Group(): + with gr.Group(elem_id="txt2img_script_container"): custom_inputs = 
modules.scripts.scripts_txt2img.setup_ui() txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) @@ -821,10 +821,10 @@ def create_ui(): with gr.Column(variant='panel', elem_id="img2img_settings"): with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: - with gr.TabItem('img2img', id='img2img'): + with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) - with gr.TabItem('Inpaint', id='inpaint'): + with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) init_img_with_mask_orig = gr.State(None) @@ -843,24 +843,24 @@ def create_ui(): init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") with gr.Row(): - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4) - mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch) + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") + mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") with gr.Row(): mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") - inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index") + inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index") + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") with gr.Row(): - inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False) - inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32) + inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False, elem_id="img2img_inpaint_full_res") + inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - with gr.TabItem('Batch img2img', id='batch'): + with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>
") - img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs) - img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs) + img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") + img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") with gr.Row(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") @@ -872,20 +872,20 @@ def create_ui(): height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") with gr.Row(): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) - tiling = gr.Checkbox(label='Tiling', value=False) + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") with gr.Row(): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1) - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1) + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") with gr.Group(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0) - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75) + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') - with gr.Group(): + with gr.Group(elem_id="img2img_script_container"): custom_inputs = modules.scripts.scripts_img2img.setup_ui() img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) @@ -1032,45 +1032,45 @@ def create_ui(): with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): with gr.Tabs(elem_id="mode_extras"): - with gr.TabItem('Single Image'): - extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil") + with gr.TabItem('Single Image', elem_id="extras_single_tab"): + extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") - with gr.TabItem('Batch Process'): - image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file") + with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"): + image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") - with gr.TabItem('Batch from Directory'): - extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same 
machine where the server is running.") - extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.") - show_extras_results = gr.Checkbox(label='Show result images', value=True) + with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): + extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") + extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") + show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by'): - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4) - with gr.TabItem('Scale to'): + with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") + with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): with gr.Group(): with gr.Row(): - upscaling_resize_w = gr.Number(label="Width", value=512, precision=0) - upscaling_resize_h = gr.Number(label="Height", value=512, precision=0) - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) + upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") with gr.Group(): extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") with gr.Group(): extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1) + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") with gr.Group(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan) + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") with gr.Group(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer) - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer) + codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") + codeformer_weight 
= gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") with gr.Group(): - upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False) + upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) @@ -1117,7 +1117,7 @@ def create_ui(): with gr.Column(variant='panel'): html = gr.HTML() - generation_info = gr.Textbox(visible=False) + generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info") html2 = gr.HTML() with gr.Row(): buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) @@ -1144,13 +1144,13 @@ def create_ui(): tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") - custom_name = gr.Textbox(label="Custom Name (Optional)") - interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3) - interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method") + custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") + interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") + interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") with gr.Row(): - checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format") - save_as_half = gr.Checkbox(value=False, label="Save as float16") + checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") + save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') @@ -1165,58 +1165,58 @@ def create_ui(): with gr.Tabs(elem_id="train_tabs"): with gr.Tab(label="Create embedding"): - new_embedding_name = gr.Textbox(label="Name") - initialization_text = gr.Textbox(label="Initialization text", value="*") - nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) - overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding") + new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") + initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") + nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") + overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") with gr.Row(): with gr.Column(scale=3): gr.HTML(value="") with gr.Column(): - create_embedding = gr.Button(value="Create embedding", variant='primary') + create_embedding = 
gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") with gr.Tab(label="Create hypernetwork"): - new_hypernetwork_name = gr.Textbox(label="Name") - new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"]) - new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") - new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys) - new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"]) - new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") - new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") - overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") + new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") + new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") + new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") + new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") + new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. 
Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") + new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") + new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") + overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") with gr.Row(): with gr.Column(scale=3): gr.HTML(value="") with gr.Column(): - create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary') + create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") with gr.Tab(label="Preprocess images"): - process_src = gr.Textbox(label='Source directory') - process_dst = gr.Textbox(label='Destination directory') - process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512) - process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512) - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"]) + process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") + process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") + process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") + process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") + preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") with gr.Row(): - process_flip = gr.Checkbox(label='Create flipped copies') - process_split = gr.Checkbox(label='Split oversized images') - process_focal_crop = gr.Checkbox(label='Auto focal point crop') - process_caption = gr.Checkbox(label='Use BLIP for caption') - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True) + process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") + process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") + process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") + process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") + process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") with gr.Row(visible=False) as process_split_extra_row: - process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05) - process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05) + process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") + process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") with gr.Row(visible=False) as process_focal_crop_row: - process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', 
value=0.9, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05) - process_focal_crop_debug = gr.Checkbox(label='Create debug image') + process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") + process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") + process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") + process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") with gr.Row(): with gr.Column(scale=3): @@ -1224,8 +1224,8 @@ def create_ui(): with gr.Column(): with gr.Row(): - interrupt_preprocessing = gr.Button("Interrupt") - run_preprocess = gr.Button(value="Preprocess", variant='primary') + interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") + run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") process_split.change( fn=lambda show: gr_show(show), @@ -1248,31 +1248,31 @@ def create_ui(): train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") with gr.Row(): - embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005") - hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") - - batch_size = gr.Number(label='Batch size', value=1, precision=0) - gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0) - dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") - log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") - template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) - training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512) - training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512) - steps = gr.Number(label='Max steps', value=100000, precision=0) - create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) - save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) - save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) - preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False) + embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") + hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") + + batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") + gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") + dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") + log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") + template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file") + training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") + training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") + steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") + save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") + preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") with gr.Row(): - shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False) - tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0) + shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") + tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") with gr.Row(): - latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random']) + latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") with gr.Row(): - interrupt_training = gr.Button(value="Interrupt") - train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') - train_embedding = gr.Button(value="Train Embedding", variant='primary') + interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") + train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") + train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") params = script_callbacks.UiTrainTabParams(txt2img_preview_params) @@ -1490,7 +1490,7 @@ def create_ui(): return gr.update(value=value), opts.dumpjson() with gr.Blocks(analytics_enabled=False) as settings_interface: - settings_submit = gr.Button(value="Apply settings", variant='primary') + settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") result = gr.HTML() settings_cols = 3 @@ -1541,8 +1541,8 @@ def create_ui(): download_localization = gr.Button(value='Download localization template', elem_id="download_localization") with gr.Row(): - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') - restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") + restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary', elem_id="settings_restart_gradio") request_notifications.click( fn=lambda: None, diff --git a/style.css b/style.css index f168571e..924d4ae7 100644 --- a/style.css +++ b/style.css @@ -73,7 +73,7 @@ margin-right: auto; } -#random_seed, #random_subseed, #reuse_seed, #reuse_subseed, #open_folder{ +[id$=_random_seed], [id$=_random_subseed], [id$=_reuse_seed], [id$=_reuse_subseed], #open_folder{ min-width: auto; flex-grow: 0; padding-left: 0.25em; @@ -84,27 +84,27 @@ display: none; } -#seed_row, #subseed_row{ +[id$=_seed_row], [id$=_subseed_row]{ gap: 0.5rem; } -#subseed_show_box{ +[id$=_subseed_show_box]{ min-width: auto; flex-grow: 0; } -#subseed_show_box > div{ +[id$=_subseed_show_box] > div{ border: 0; height: 100%; } -#subseed_show{ +[id$=_subseed_show]{ min-width: auto; flex-grow: 0; padding: 0; } -#subseed_show label{ +[id$=_subseed_show] label{ height: 100%; } -- cgit v1.2.3 From 311354c0bb8930ea939d6aa6b3edd50c69301320 Mon Sep 17 00:00:00 2001 From: 
AUTOMATIC <16777216c@gmail.com> Date: Mon, 2 Jan 2023 00:38:09 +0300 Subject: fix the issue with training on SD2.0 --- modules/sd_models.py | 2 ++ modules/textual_inversion/textual_inversion.py | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index ebd4dff7..bff8d6c9 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -228,6 +228,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + model.logvar = model.logvar.to(devices.device) # fix for training + sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 66f40367..1e5722e7 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -282,7 +282,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ return embedding, filename scheduler = LearnRateScheduler(learn_rate, steps, initial_step) - # dataset loading may take a while, so input validations and early returns should be done before this + # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." old_parallel_processing_allowed = shared.parallel_processing_allowed @@ -310,7 +310,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ loss_step = 0 _loss_step = 0 #internal - last_saved_file = "" last_saved_image = "" forced_filename = "" -- cgit v1.2.3 From b5819d9bf1794071139c640b5f1e72c84a0e051a Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 2 Jan 2023 10:17:33 +1100 Subject: feat(api): add /sdapi/v1/embeddings --- modules/api/api.py | 8 ++++++++ modules/api/models.py | 3 +++ 2 files changed, 11 insertions(+) diff --git a/modules/api/api.py b/modules/api/api.py index 11daff0d..30bf3dac 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -100,6 +100,7 @@ class Api: self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem]) self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str]) self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem]) + self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse) self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse) @@ -327,6 +328,13 @@ class Api: def get_artists(self): return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists] + def get_embeddings(self): + db = sd_hijack.model_hijack.embedding_db + return { + "loaded": sorted(db.word_embeddings.keys()), + "skipped": sorted(db.skipped_embeddings), + } + def refresh_checkpoints(self): shared.refresh_checkpoints() diff --git a/modules/api/models.py b/modules/api/models.py index c446ce7a..a8472dc9 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ 
-249,3 +249,6 @@ class ArtistItem(BaseModel): score: float = Field(title="Score") category: str = Field(title="Category") +class EmbeddingsResponse(BaseModel): + loaded: List[str] = Field(title="Loaded", description="Embeddings loaded for the current model") + skipped: List[str] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)") \ No newline at end of file -- cgit v1.2.3 From c65909ad16a1962129114c6251de092f49479b06 Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 2 Jan 2023 12:21:22 +1100 Subject: feat(api): return more data for embeddings --- modules/api/api.py | 17 +++++++++++++++-- modules/api/models.py | 11 +++++++++-- modules/textual_inversion/textual_inversion.py | 8 ++++---- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 30bf3dac..9c670f00 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -330,9 +330,22 @@ class Api: def get_embeddings(self): db = sd_hijack.model_hijack.embedding_db + + def convert_embedding(embedding): + return { + "step": embedding.step, + "sd_checkpoint": embedding.sd_checkpoint, + "sd_checkpoint_name": embedding.sd_checkpoint_name, + "shape": embedding.shape, + "vectors": embedding.vectors, + } + + def convert_embeddings(embeddings): + return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()} + return { - "loaded": sorted(db.word_embeddings.keys()), - "skipped": sorted(db.skipped_embeddings), + "loaded": convert_embeddings(db.word_embeddings), + "skipped": convert_embeddings(db.skipped_embeddings), } def refresh_checkpoints(self): diff --git a/modules/api/models.py b/modules/api/models.py index a8472dc9..4a632c68 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -249,6 +249,13 @@ class ArtistItem(BaseModel): score: float = Field(title="Score") category: str = Field(title="Category") +class EmbeddingItem(BaseModel): + step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available") + sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available") + sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. 
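After this change both "loaded" and "skipped" become dictionaries keyed by embedding name, so a client can inspect vector counts directly instead of receiving bare name lists. A small sketch of querying the endpoint, assuming a local webui instance on the default 127.0.0.1:7860 with the API enabled:

```python
import requests

resp = requests.get("http://127.0.0.1:7860/sdapi/v1/embeddings")
resp.raise_for_status()
data = resp.json()

for name, item in data["loaded"].items():
    # fields mirror EmbeddingItem: step, sd_checkpoint, shape, vectors, ...
    print(f"{name}: {item['vectors']} vector(s) of dimension {item['shape']}")

if data["skipped"]:
    print("skipped:", ", ".join(data["skipped"]))
```
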
Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead") + shape: int = Field(title="Shape", description="The length of each individual vector in the embedding") + vectors: int = Field(title="Vectors", description="The number of vectors in the embedding") + class EmbeddingsResponse(BaseModel): - loaded: List[str] = Field(title="Loaded", description="Embeddings loaded for the current model") - skipped: List[str] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)") \ No newline at end of file + loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model") + skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)") \ No newline at end of file diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 1e5722e7..fd253477 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -59,7 +59,7 @@ class EmbeddingDatabase: def __init__(self, embeddings_dir): self.ids_lookup = {} self.word_embeddings = {} - self.skipped_embeddings = [] + self.skipped_embeddings = {} self.dir_mtime = None self.embeddings_dir = embeddings_dir self.expected_shape = -1 @@ -91,7 +91,7 @@ class EmbeddingDatabase: self.dir_mtime = mt self.ids_lookup.clear() self.word_embeddings.clear() - self.skipped_embeddings = [] + self.skipped_embeddings.clear() self.expected_shape = self.get_expected_shape() def process_file(path, filename): @@ -136,7 +136,7 @@ class EmbeddingDatabase: if self.expected_shape == -1 or self.expected_shape == embedding.shape: self.register_embedding(embedding, shared.sd_model) else: - self.skipped_embeddings.append(name) + self.skipped_embeddings[name] = embedding for fn in os.listdir(self.embeddings_dir): try: @@ -153,7 +153,7 @@ class EmbeddingDatabase: print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") if len(self.skipped_embeddings) > 0: - print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings)}") + print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") def find_embedding_at_position(self, tokens, offset): token = tokens[offset] -- cgit v1.2.3 From ef27a18b6b7cb1a8eebdc9b2e88d25baf2c2414d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 2 Jan 2023 19:42:10 +0300 Subject: Hires fix rework --- modules/generation_parameters_copypaste.py | 32 ++++++++++++++ modules/images.py | 24 +++++++++-- modules/processing.py | 68 ++++++++++++------------------ modules/shared.py | 7 ++- modules/txt2img.py | 6 +-- modules/ui.py | 15 +++---- scripts/xy_grid.py | 4 +- 7 files changed, 96 insertions(+), 60 deletions(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 8e7f0df0..d6fa822b 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -1,5 +1,6 @@ import base64 import io +import math import os import re from pathlib import Path @@ -164,6 +165,35 @@ def find_hypernetwork_key(hypernet_name, hypernet_hash=None): return None +def restore_old_hires_fix_params(res): + """for infotexts that specify old First pass 
size parameter, convert it into + width, height, and hr scale""" + + firstpass_width = res.get('First pass size-1', None) + firstpass_height = res.get('First pass size-2', None) + + if firstpass_width is None or firstpass_height is None: + return + + firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height) + width = int(res.get("Size-1", 512)) + height = int(res.get("Size-2", 512)) + + if firstpass_width == 0 or firstpass_height == 0: + # old algorithm for auto-calculating first pass size + desired_pixel_count = 512 * 512 + actual_pixel_count = width * height + scale = math.sqrt(desired_pixel_count / actual_pixel_count) + firstpass_width = math.ceil(scale * width / 64) * 64 + firstpass_height = math.ceil(scale * height / 64) * 64 + + hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height + + res['Size-1'] = firstpass_width + res['Size-2'] = firstpass_height + res['Hires upscale'] = hr_scale + + def parse_generation_parameters(x: str): """parses generation parameters string, the one you see in text field under the picture in UI: ``` @@ -221,6 +251,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hypernet_hash = res.get("Hypernet hash", None) res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash) + restore_old_hires_fix_params(res) + return res diff --git a/modules/images.py b/modules/images.py index f84fd485..c3a5fc8b 100644 --- a/modules/images.py +++ b/modules/images.py @@ -230,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts): return draw_grid_annotations(im, width, height, hor_texts, ver_texts) -def resize_image(resize_mode, im, width, height): +def resize_image(resize_mode, im, width, height, upscaler_name=None): + """ + Resizes an image with the specified resize_mode, width, and height. + + Args: + resize_mode: The mode to use when resizing the image. + 0: Resize the image to the specified width and height. + 1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess. + 2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image. + im: The image to resize. + width: The width to resize the image to. + height: The height to resize the image to. + upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img. 
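The conversion in restore_old_hires_fix_params is easy to check by hand. Under the old auto-sizing branch (first pass stored as 0x0), the first pass aimed for roughly 512x512 and snapped each side up to a multiple of 64. A worked example of the same arithmetic:

```python
import math

width, height = 960, 640                 # final (second pass) size from the infotext
firstpass_width = firstpass_height = 0   # 0x0 meant "auto" in old versions

if firstpass_width == 0 or firstpass_height == 0:
    scale = math.sqrt(512 * 512 / (width * height))
    firstpass_width = math.ceil(scale * width / 64) * 64    # -> 640
    firstpass_height = math.ceil(scale * height / 64) * 64  # -> 448

hr_scale = width / firstpass_width
print(firstpass_width, firstpass_height, hr_scale)          # 640 448 1.5
```
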
+ """ + + upscaler_name = upscaler_name or opts.upscaler_for_img2img + def resize(im, w, h): - if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L': + if upscaler_name is None or upscaler_name == "None" or im.mode == 'L': return im.resize((w, h), resample=LANCZOS) scale = max(w / im.width, h / im.height) if scale > 1.0: - upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img] - assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}" + upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name] + assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}" upscaler = upscalers[0] im = upscaler.scaler.upscale(im, scale, upscaler.data_path) diff --git a/modules/processing.py b/modules/processing.py index 42dc19ea..4654570c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -658,14 +658,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None - def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs): + def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr self.denoising_strength = denoising_strength - self.firstphase_width = firstphase_width - self.firstphase_height = firstphase_height - self.truncate_x = 0 - self.truncate_y = 0 + self.hr_scale = hr_scale + self.hr_upscaler = hr_upscaler + + if firstphase_width != 0 or firstphase_height != 0: + print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr) + self.hr_scale = self.width / firstphase_width + self.width = firstphase_width + self.height = firstphase_height def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: @@ -674,47 +678,29 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: state.job_count = state.job_count * 2 - self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}" - - if self.firstphase_width == 0 or self.firstphase_height == 0: - desired_pixel_count = 512 * 512 - actual_pixel_count = self.width * self.height - scale = math.sqrt(desired_pixel_count / actual_pixel_count) - self.firstphase_width = math.ceil(scale * self.width / 64) * 64 - self.firstphase_height = math.ceil(scale * self.height / 64) * 64 - firstphase_width_truncated = int(scale * self.width) - firstphase_height_truncated = int(scale * self.height) - - else: - - width_ratio = self.width / self.firstphase_width - height_ratio = self.height / self.firstphase_height - - if width_ratio > height_ratio: - firstphase_width_truncated = self.firstphase_width - firstphase_height_truncated = self.firstphase_width * self.height / self.width - else: - firstphase_width_truncated = self.firstphase_height * self.width / self.height - firstphase_height_truncated = self.firstphase_height - - self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f - self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f + self.extra_generation_params["Hires upscale"] = self.hr_scale + if self.hr_upscaler is not None: + self.extra_generation_params["Hires upscaler"] = self.hr_upscaler def sample(self, conditioning, 
unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) + latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_default_mode + if self.enable_hr and latent_scale_mode is None: + assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}" + + x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) + if not self.enable_hr: - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) return samples - x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height)) - - samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] + target_width = int(self.width * self.hr_scale) + target_height = int(self.height * self.hr_scale) - """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images""" def save_intermediate(image, index): + """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images""" + if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix: return @@ -723,11 +709,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix") - if opts.use_scale_latent_for_hires_fix: + if latent_scale_mode is not None: for i in range(samples.shape[0]): save_intermediate(samples, i) - samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") + samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode) # Avoid making the inpainting conditioning unless necessary as # this does need some extra compute to decode / encode the image again. 
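When one of the "Latent" modes is selected, the hires pass never leaves latent space: the 4-channel latent batch is interpolated straight to the target size divided by the VAE downscale factor (opt_f = 8), using whichever interpolation mode the latent_upscale_modes table maps the name to. A self-contained shape check:

```python
import torch
import torch.nn.functional as F

opt_f = 8                                                  # VAE downscale factor
samples = torch.randn(1, 4, 512 // opt_f, 512 // opt_f)    # latents for 512x512
hr_scale = 2.0
target_w = target_h = int(512 * hr_scale)

upscaled = F.interpolate(samples, size=(target_h // opt_f, target_w // opt_f), mode="bilinear")
print(upscaled.shape)  # torch.Size([1, 4, 128, 128])
```
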
@@ -747,7 +733,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): save_intermediate(image, i) - image = images.resize_image(0, image, self.width, self.height) + image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler) image = np.array(image).astype(np.float32) / 255.0 image = np.moveaxis(image, 2, 0) batch_images.append(image) @@ -764,7 +750,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self) # GC now before running the next img2img to prevent running out of memory x = None diff --git a/modules/shared.py b/modules/shared.py index 7f430b93..b65559ee 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -327,7 +327,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), - "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. 
fix"), })) options_templates.update(options_section(('face-restoration', "Face restoration"), { @@ -545,6 +544,12 @@ opts = Options() if os.path.exists(config_filename): opts.load(config_filename) +latent_upscale_default_mode = "Latent" +latent_upscale_modes = { + "Latent": "bilinear", + "Latent (nearest)": "nearest", +} + sd_upscalers = [] sd_model = None diff --git a/modules/txt2img.py b/modules/txt2img.py index 7f61e19a..e189a899 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -8,7 +8,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -33,8 +33,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: tiling=tiling, enable_hr=enable_hr, denoising_strength=denoising_strength if enable_hr else None, - firstphase_width=firstphase_width if enable_hr else None, - firstphase_height=firstphase_height if enable_hr else None, + hr_scale=hr_scale, + hr_upscaler=hr_upscaler, ) p.scripts = modules.scripts.scripts_txt2img diff --git a/modules/ui.py b/modules/ui.py index 7070ea15..27cd9ddd 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -684,11 +684,11 @@ def create_ui(): with gr.Row(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") - enable_hr = gr.Checkbox(label='Highres. fix', value=False, elem_id="txt2img_enable_hr") + enable_hr = gr.Checkbox(label='Hires. 
fix', value=False, elem_id="txt2img_enable_hr") with gr.Row(visible=False) as hr_options: - firstphase_width = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass width", value=0, elem_id="txt2img_firstphase_width") - firstphase_height = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass height", value=0, elem_id="txt2img_firstphase_height") + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") with gr.Row(equal_height=True): @@ -729,8 +729,8 @@ def create_ui(): width, enable_hr, denoising_strength, - firstphase_width, - firstphase_height, + hr_scale, + hr_upscaler, ] + custom_inputs, outputs=[ @@ -762,7 +762,6 @@ def create_ui(): outputs=[hr_options], ) - txt2img_paste_fields = [ (txt2img_prompt, "Prompt"), (txt2img_negative_prompt, "Negative prompt"), @@ -781,8 +780,8 @@ def create_ui(): (denoising_strength, "Denoising strength"), (enable_hr, lambda d: "Denoising strength" in d), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), - (firstphase_width, "First pass size-1"), - (firstphase_height, "First pass size-2"), + (hr_scale, "Hires upscale"), + (hr_upscaler, "Hires upscaler"), *modules.scripts.scripts_txt2img.infotext_fields ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 3e0b2805..f92f9776 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -202,7 +202,7 @@ axis_options = [ AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), - AxisOption("Upscale latent space for hires.", str, apply_upscale_latent_space, format_value_add_label, None), + AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None), AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None), AxisOption("VAE", str, apply_vae, format_value_add_label, None), AxisOption("Styles", str, apply_styles, format_value_add_label, None), @@ -267,7 +267,6 @@ class SharedSettingsStackHelper(object): self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers self.hypernetwork = opts.sd_hypernetwork self.model = shared.sd_model - self.use_scale_latent_for_hires_fix = opts.use_scale_latent_for_hires_fix self.vae = opts.sd_vae def __exit__(self, exc_type, exc_value, tb): @@ -278,7 +277,6 @@ class SharedSettingsStackHelper(object): hypernetwork.apply_strength() opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers - opts.data["use_scale_latent_for_hires_fix"] = self.use_scale_latent_for_hires_fix re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") -- cgit v1.2.3 From 4dbde228ff48dbb105241b1ed25c21ce3f87d182 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 2 Jan 2023 20:01:16 +0300 Subject: make it possible to use fractional values for SD upscale. 
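Fractional factors only require rounding at the destination size; a fixed-factor model (an ESRGAN 4x, say) is still applied in whole passes, up to three, and the result is then resized down to the exact target. A loose sketch of that flow, sizes only, with no actual model:

```python
def plan_upscale(size, scale, model_factor=4):
    w, h = size
    dest_w, dest_h = int(w * scale), int(h * scale)  # int() is what admits 2.5x
    for _ in range(3):
        if w >= dest_w and h >= dest_h:
            break
        w, h = w * model_factor, h * model_factor
    return (dest_w, dest_h), (w, h)

# 512x512 at 2.5x: one 4x pass overshoots to 2048, the final resize lands on 1280
print(plan_upscale((512, 512), 2.5))  # ((1280, 1280), (2048, 2048))
```
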
---
 modules/upscaler.py   | 6 +++---
 scripts/sd_upscale.py | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/modules/upscaler.py b/modules/upscaler.py
index c4e6e6bd..231680cb 100644
--- a/modules/upscaler.py
+++ b/modules/upscaler.py
@@ -53,10 +53,10 @@ class Upscaler:
     def do_upscale(self, img: PIL.Image, selected_model: str):
         return img

-    def upscale(self, img: PIL.Image, scale: int, selected_model: str = None):
+    def upscale(self, img: PIL.Image, scale, selected_model: str = None):
         self.scale = scale
-        dest_w = img.width * scale
-        dest_h = img.height * scale
+        dest_w = int(img.width * scale)
+        dest_h = int(img.height * scale)

         for i in range(3):
             shape = (img.width, img.height)

diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py
index e8c80a6c..9739545c 100644
--- a/scripts/sd_upscale.py
+++ b/scripts/sd_upscale.py
@@ -19,7 +19,7 @@ class Script(scripts.Script):
     def ui(self, is_img2img):
         info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>")
         overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64)
-        scale_factor = gr.Slider(minimum=1, maximum=4, step=1, label='Scale Factor', value=2)
+        scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0)
         upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")

         return [info, overlap, upscaler_index, scale_factor]
-- cgit v1.2.3

From 84dd7e8e2495c4fc2997e97f8267aa831eb90d11 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 2 Jan 2023 20:30:02 +0300
Subject: error out with a readable message in checkpoint merger for incompatible tensor shapes (i.e. when trying to merge SD1.5 with SD2.0)

---
 modules/extras.py | 2 ++
 modules/ui.py     | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/modules/extras.py b/modules/extras.py
index 68939dea..5e270250 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -303,6 +303,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
             theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
             result_is_inpainting_model = True
         else:
+            assert a.shape == b.shape, f'Incompatible shapes for layer {key}: A is {a.shape}, and B is {b.shape}'
+
             theta_0[key] = theta_func2(a, b, multiplier)

         if save_as_half:

diff --git a/modules/ui.py b/modules/ui.py
index 27cd9ddd..67a51888 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1663,7 +1663,7 @@ def create_ui():
         print("Error loading/saving model file:", file=sys.stderr)
         print(traceback.format_exc(), file=sys.stderr)
         modules.sd_models.list_models()  # to remove the potentially missing models from the list
-        return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
+        return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)]

     return results

 modelmerger_merge.click(
-- cgit v1.2.3

From 8d12a729b8b036cb765cf2d87576d5ae256135c8 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 2 Jan 2023 20:46:51 +0300
Subject: fix possible error with accessing nonexistent setting

---
 modules/ui.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/ui.py b/modules/ui.py
index 67a51888..9350a80f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -491,7 +491,7 @@ def apply_setting(key, value):
         return

     valtype = type(opts.data_labels[key].default)
-    oldval = opts.data[key]
+    oldval = opts.data.get(key, None)
     opts.data[key] = valtype(value) if valtype != type(None) else value
     if oldval != value and opts.data_labels[key].onchange is not None:
         opts.data_labels[key].onchange()
-- cgit v1.2.3

From 251ecee6949c36e9df1d99a950b3e1af2b5fa2b6 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 2 Jan 2023 22:44:46 +0300
Subject: make "send to" buttons send actual dimension of the sent image rather than fields

---
 javascript/ui.js                           |  4 +--
 modules/generation_parameters_copypaste.py | 58 ++++++++++++++++++++----------
 2 files changed, 42 insertions(+), 20 deletions(-)

diff --git a/javascript/ui.js b/javascript/ui.js
index 587dd782..d0c054d9 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -19,7 +19,7 @@ function selected_gallery_index(){
 function extract_image_from_gallery(gallery){
     if(gallery.length == 1){
-        return gallery[0]
+        return [gallery[0]]
     }

     index = 
selected_gallery_index() @@ -28,7 +28,7 @@ function extract_image_from_gallery(gallery){ return [null] } - return gallery[index]; + return [gallery[index]]; } function args_to_array(args){ diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index d6fa822b..ec60319a 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -103,35 +103,57 @@ def bind_buttons(buttons, send_image, send_generate_info): bind_list.append([buttons, send_image, send_generate_info]) +def send_image_and_dimensions(x): + if isinstance(x, Image.Image): + img = x + else: + img = image_from_url_text(x) + + if shared.opts.send_size and isinstance(img, Image.Image): + w = img.width + h = img.height + else: + w = gr.update() + h = gr.update() + + return img, w, h + + def run_bind(): - for buttons, send_image, send_generate_info in bind_list: + for buttons, source_image_component, send_generate_info in bind_list: for tab in buttons: button = buttons[tab] - if send_image and paste_fields[tab]["init_img"]: - if type(send_image) == gr.Gallery: - button.click( - fn=lambda x: image_from_url_text(x), - _js="extract_image_from_gallery", - inputs=[send_image], - outputs=[paste_fields[tab]["init_img"]], - ) + destination_image_component = paste_fields[tab]["init_img"] + fields = paste_fields[tab]["fields"] + + destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None) + destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None) + + if source_image_component and destination_image_component: + if isinstance(source_image_component, gr.Gallery): + func = send_image_and_dimensions if destination_width_component else image_from_url_text + jsfunc = "extract_image_from_gallery" else: - button.click( - fn=lambda x: x, - inputs=[send_image], - outputs=[paste_fields[tab]["init_img"]], - ) + func = send_image_and_dimensions if destination_width_component else lambda x: x + jsfunc = None + + button.click( + fn=func, + _js=jsfunc, + inputs=[source_image_component], + outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component], + ) - if send_generate_info and paste_fields[tab]["fields"] is not None: + if send_generate_info and fields is not None: if send_generate_info in paste_fields: - paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (['Size-1', 'Size-2'] if shared.opts.send_size else []) + (["Seed"] if shared.opts.send_seed else []) + paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) button.click( fn=lambda *x: x, inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names], - outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names], + outputs=[field for field, name in fields if name in paste_field_names], ) else: - connect_paste(button, paste_fields[tab]["fields"], send_generate_info) + connect_paste(button, fields, send_generate_info) button.click( fn=None, -- cgit v1.2.3 From 1d7a31def8b5f4c348e2dd07536ac56cb4350614 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 06:21:53 +0300 Subject: make edit fields for sliders not get hidden by slider's label when there's not enough space --- style.css | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/style.css b/style.css index 924d4ae7..77551dd7 100644 --- a/style.css +++ b/style.css @@ -509,7 +509,7 @@ canvas[key="mask"] { position: absolute; right: 0.5em; top: -0.6em; - z-index: 200; + z-index: 400; width: 8em; } #quicksettings .gr-box > div > div > input.gr-text-input { -- cgit v1.2.3 From 269f6e867651cadef40d2c939a79d13291280bcd Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 07:20:20 +0300 Subject: change settings UI to use vertical tabs --- modules/ui.py | 45 +++++++++++++++++---------------------------- style.css | 27 +++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 28 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 9350a80f..f8c973ba 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1489,41 +1489,34 @@ def create_ui(): return gr.update(value=value), opts.dumpjson() with gr.Blocks(analytics_enabled=False) as settings_interface: - settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") - result = gr.HTML() + with gr.Row(): + settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") + restart_gradio = gr.Button(value='Restart UI', variant='primary', elem_id="settings_restart_gradio") - settings_cols = 3 - items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols) + result = gr.HTML(elem_id="settings_result") quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings') quicksettings_list = [] - cols_displayed = 0 - items_displayed = 0 previous_section = None - column = None - with gr.Row(elem_id="settings").style(equal_height=False): + current_tab = None + with gr.Tabs(elem_id="settings"): for i, (k, item) in enumerate(opts.data_labels.items()): section_must_be_skipped = item.section[0] is None if previous_section != item.section and not section_must_be_skipped: - if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None): - if column is not None: - column.__exit__() + elem_id, text = item.section - column = gr.Column(variant='panel') - column.__enter__() + if current_tab is not None: + current_tab.__exit__() - items_displayed = 0 - cols_displayed += 1 + current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) + current_tab.__enter__() previous_section = item.section - elem_id, text = item.section - gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='
<h1 class="gr-button-lg">{}</h1>
'.format(text)) - if k in quicksettings_names and not shared.cmd_opts.freeze_settings: quicksettings_list.append((i, k, item)) components.append(dummy_component) @@ -1533,15 +1526,14 @@ def create_ui(): component = create_setting_component(k) component_dict[k] = component components.append(component) - items_displayed += 1 - with gr.Row(): - request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") - download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + if current_tab is not None: + current_tab.__exit__() - with gr.Row(): - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") - restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary', elem_id="settings_restart_gradio") + with gr.TabItem("Actions"): + request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") request_notifications.click( fn=lambda: None, @@ -1578,9 +1570,6 @@ def create_ui(): outputs=[], ) - if column is not None: - column.__exit__() - interfaces = [ (txt2img_interface, "txt2img", "txt2img"), (img2img_interface, "img2img", "img2img"), diff --git a/style.css b/style.css index 77551dd7..7df4d960 100644 --- a/style.css +++ b/style.css @@ -241,6 +241,33 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s z-index: 200; } +#settings{ + display: block; +} + +#settings > div{ + border: none; + margin-left: 10em; +} + +#settings > div.flex-wrap{ + float: left; + display: block; + margin-left: 0; + width: 10em; +} + +#settings > div.flex-wrap button{ + display: block; + border: none; + text-align: left; +} + +#settings_result{ + height: 1.4em; + margin: 0 1.2em; +} + input[type="range"]{ margin: 0.5em 0 -0.3em 0; } -- cgit v1.2.3 From 18c03cdeac6272734b0c09afd3fbe47d1372dd07 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 09:04:29 +0300 Subject: styling rework to make things more compact --- modules/ui.py | 121 ++++++++++++++++++++++++----------------------- modules/ui_components.py | 7 +++ style.css | 35 ++++++++------ 3 files changed, 89 insertions(+), 74 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index f8c973ba..f787b518 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -19,7 +19,8 @@ import numpy as np from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, ui_components +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru +from modules.ui_components import FormRow, FormGroup, ToolButton from modules.paths import script_path from modules.shared import opts, cmd_opts, restricted_opts @@ -273,31 +274,27 @@ def interrogate_deepbooru(image): def create_seed_inputs(target_interface): - with gr.Row(): - with gr.Box(): - with gr.Row(elem_id=target_interface + '_seed_row'): - seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', 
value=-1, elem_id=target_interface + '_seed') - seed.style(container=False) - random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') - reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') - - with gr.Box(elem_id=target_interface + '_subseed_show_box'): + with FormRow(elem_id=target_interface + '_seed_row'): + seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') + seed.style(container=False) + random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') + reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') + + with gr.Group(elem_id=target_interface + '_subseed_show_box'): seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) # Components to show/hide based on the 'Extra' checkbox seed_extras = [] - with gr.Row(visible=False) as seed_extra_row_1: + with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1: seed_extras.append(seed_extra_row_1) - with gr.Box(): - with gr.Row(elem_id=target_interface + '_subseed_row'): - subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') - subseed.style(container=False) - random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') - reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') + subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') + subseed.style(container=False) + random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') + reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength') - with gr.Row(visible=False) as seed_extra_row_2: + with FormRow(visible=False) as seed_extra_row_2: seed_extras.append(seed_extra_row_2) seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w') seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h') @@ -523,7 +520,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele return gr.update(**(args or {})) - refresh_button = ui_components.ToolButton(value=refresh_symbol, elem_id=elem_id) + refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) refresh_button.click( fn=refresh, inputs=[], @@ -636,11 +633,11 @@ Requested path was: {f} def create_sampler_and_steps_selection(choices, tabname): if opts.samplers_in_dropdown: - with gr.Row(elem_id=f"sampler_selection_{tabname}"): + with FormRow(elem_id=f"sampler_selection_{tabname}"): sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20) else: - with gr.Group(elem_id=f"sampler_selection_{tabname}"): + with FormGroup(elem_id=f"sampler_selection_{tabname}"): steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20) sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", 
choices=[x.name for x in choices], value=choices[0].name, type="index") @@ -677,29 +674,29 @@ def create_ui(): with gr.Column(variant='panel', elem_id="txt2img_settings"): steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") - with gr.Group(): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") + with FormRow(): + with gr.Column(elem_id="txt2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") + with gr.Column(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - with gr.Row(): + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") + + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') + + with FormRow(elem_id="txt2img_checkboxes"): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") - with gr.Row(visible=False) as hr_options: + with FormRow(visible=False) as hr_options: hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") - with gr.Row(equal_height=True): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") - - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') - - with gr.Group(elem_id="txt2img_script_container"): + with FormGroup(elem_id="txt2img_script_container"): custom_inputs = modules.scripts.scripts_txt2img.setup_ui() txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) @@ -816,7 +813,7 @@ def create_ui(): img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) setup_progressbar(progressbar, img2img_preview, 'img2img') - with gr.Row().style(equal_height=False): + with FormRow().style(equal_height=False): with gr.Column(variant='panel', elem_id="img2img_settings"): with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: @@ -841,19 +838,23 @@ def create_ui(): init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, 
type="pil", visible=False, elem_id="img_inpaint_base") init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") - with gr.Row(): + with FormRow(): mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") - with gr.Row(): - mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") - inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") + with FormRow(): + mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") + inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") + with FormRow(): + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") - with gr.Row(): - inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False, elem_id="img2img_inpaint_full_res") - inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") + with FormRow(): + with gr.Column(): + inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + + with gr.Column(scale=4): + inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' @@ -861,30 +862,30 @@ def create_ui(): img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") - with gr.Row(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + with FormRow(): + resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") - with gr.Group(): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") - - with gr.Row(): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") + with FormRow(): + with gr.Column(elem_id="img2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") + with gr.Column(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - with gr.Row(): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - with gr.Group(): + with FormGroup(): cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') - with gr.Group(elem_id="img2img_script_container"): + with FormRow(elem_id="img2img_checkboxes"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") + + with FormGroup(elem_id="img2img_script_container"): custom_inputs = modules.scripts.scripts_img2img.setup_ui() img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) @@ -1444,7 +1445,7 @@ def create_ui(): res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) else: - with ui_components.FormRow(): + with FormRow(): res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) create_refresh_button(res, info.refresh, 
info.component_args, "refresh_" + key) else: diff --git a/modules/ui_components.py b/modules/ui_components.py index d0519d2d..91eb0e3d 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -16,3 +16,10 @@ class FormRow(gr.Row, gr.components.FormComponent): def get_block_name(self): return "row" + + +class FormGroup(gr.Group, gr.components.FormComponent): + """Same as gr.Row but fits inside gradio forms""" + + def get_block_name(self): + return "group" diff --git a/style.css b/style.css index 7df4d960..86a265f6 100644 --- a/style.css +++ b/style.css @@ -74,7 +74,8 @@ } [id$=_random_seed], [id$=_random_subseed], [id$=_reuse_seed], [id$=_reuse_subseed], #open_folder{ - min-width: auto; + min-width: 2.3em; + height: 2.5em; flex-grow: 0; padding-left: 0.25em; padding-right: 0.25em; @@ -86,6 +87,7 @@ [id$=_seed_row], [id$=_subseed_row]{ gap: 0.5rem; + padding: 0.6em; } [id$=_subseed_show_box]{ @@ -206,24 +208,24 @@ button{ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block span{ position: absolute; - top: -0.6em; + top: -0.5em; line-height: 1.2em; - padding: 0 0.5em; - margin: 0; + padding: 0; + margin: 0 0.5em; background-color: white; - border-top: 1px solid #eee; - border-left: 1px solid #eee; - border-right: 1px solid #eee; + box-shadow: 0 0 5px 5px white; z-index: 300; } .dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{ background-color: rgb(31, 41, 55); - border-top: 1px solid rgb(55 65 81); - border-left: 1px solid rgb(55 65 81); - border-right: 1px solid rgb(55 65 81); + box-shadow: 0 0 5px 5px rgb(31, 41, 55); +} + +#txt2img_column_batch, #img2img_column_batch{ + min-width: min(13.5em, 100%) !important; } #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{ @@ -232,10 +234,6 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s margin-right: 8em; } -.gr-panel div.flex-col div.justify-between label span{ - margin: 0; -} - #settings .gr-panel div.flex-col div.justify-between div{ position: relative; z-index: 200; @@ -609,6 +607,15 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h } +#img2img_settings > div.gr-form, #txt2img_settings > div.gr-form { + padding-top: 0.9em; +} + +#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form{ + border: none; + padding-bottom: 0.5em; +} + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. 
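A note on the FormRow/FormGroup pattern the styling-rework commit above introduces in modules/ui_components.py: Gradio resolves a block's layout from get_block_name(), so a subclass that mixes in gr.components.FormComponent while still reporting the base block name ("row", "group") renders exactly like the plain container yet is allowed to participate in gradio's form grouping — which is what lets the compacted settings rows share one form. A minimal usage sketch; the Blocks wrapper and checkbox labels are illustrative, not taken from the patch:

import gradio as gr

from modules.ui_components import FormRow

with gr.Blocks() as demo:
    # lays out like gr.Row, but is treated as a form component,
    # so it merges into the surrounding gradio form styling
    with FormRow():
        restore_faces = gr.Checkbox(label='Restore faces', value=False)
        tiling = gr.Checkbox(label='Tiling', value=False)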
-- cgit v1.2.3 From 2bc86712ec16cada01a2353f1d978c1aabc84dbb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 09:13:35 +0300 Subject: make quicksettings UI elements appear in same order as they are listed in the setting --- modules/ui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index f787b518..d7b911da 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1497,7 +1497,7 @@ def create_ui(): result = gr.HTML(elem_id="settings_result") quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] - quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings') + quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} quicksettings_list = [] @@ -1604,7 +1604,7 @@ def create_ui(): with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Row(elem_id="quicksettings"): - for i, k, item in quicksettings_list: + for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): component = create_setting_component(k, is_quicksettings=True) component_dict[k] = component -- cgit v1.2.3 From 9d4eff097deff6153c4023f158bd9fbd4f3e88b3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 10:01:06 +0300 Subject: add a button to show all setting pages --- javascript/ui.js | 11 +++++++++++ modules/ui.py | 2 ++ 2 files changed, 13 insertions(+) diff --git a/javascript/ui.js b/javascript/ui.js index d0c054d9..34406f3f 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -188,6 +188,17 @@ onUiUpdate(function(){ img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea"); img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button")); } + + show_all_pages = gradioApp().getElementById('settings_show_all_pages') + settings_tabs = gradioApp().querySelector('#settings div') + if(show_all_pages && settings_tabs){ + settings_tabs.appendChild(show_all_pages) + show_all_pages.onclick = function(){ + gradioApp().querySelectorAll('#settings > div').forEach(function(elem){ + elem.style.display = "block"; + }) + } + } }) let txt2img_textarea, img2img_textarea = undefined; diff --git a/modules/ui.py b/modules/ui.py index d7b911da..2c92c422 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1536,6 +1536,8 @@ def create_ui(): download_localization = gr.Button(value='Download localization template', elem_id="download_localization") reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") + gr.Button(value="Show all pages", elem_id="settings_show_all_pages") + request_notifications.click( fn=lambda: None, inputs=[], -- cgit v1.2.3 From a1cf55a9d1c82f8e56c00d549bca5c8fa069f412 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 10:39:21 +0300 Subject: add option to reorder items in main UI --- modules/shared.py | 13 ++++++ modules/ui.py | 130 +++++++++++++++++++++++++++++++++++------------------- 2 files changed, 97 insertions(+), 46 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index b65559ee..23657a93 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -109,6 +109,17 @@ restricted_opts = { "outdir_save", } +ui_reorder_categories = [ + "sampler", + "dimensions", + "cfg", + "seed", + "checkboxes", + "hires_fix", + "batch", + "scripts", +] + 
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \ @@ -410,7 +421,9 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"), + "dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), + 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/ing2img UI item order"), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) diff --git a/modules/ui.py b/modules/ui.py index 2c92c422..f2e7c0d6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -644,6 +644,13 @@ def create_sampler_and_steps_selection(choices, tabname): return steps, sampler_index +def ordered_ui_categories(): + user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))} + + for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)): + yield category + + def create_ui(): import modules.img2img import modules.txt2img @@ -672,32 +679,48 @@ def create_ui(): with gr.Row().style(equal_height=False): with gr.Column(variant='panel', elem_id="txt2img_settings"): - steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") - - with FormRow(): - with gr.Column(elem_id="txt2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") - with gr.Column(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") - - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') - - with FormRow(elem_id="txt2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") - enable_hr = gr.Checkbox(label='Hires. 
fix', value=False, elem_id="txt2img_enable_hr") + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") - with FormRow(visible=False) as hr_options: - hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) - hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") - - with FormGroup(elem_id="txt2img_script_container"): - custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="txt2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") + + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "cfg": + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') + + elif category == "checkboxes": + with FormRow(elem_id="txt2img_checkboxes"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") + enable_hr = gr.Checkbox(label='Hires. 
fix', value=False, elem_id="txt2img_enable_hr") + + elif category == "hires_fix": + with FormRow(visible=False, elem_id="txt2img_hires_fix") as hr_options: + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") + + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "scripts": + with FormGroup(elem_id="txt2img_script_container"): + custom_inputs = modules.scripts.scripts_txt2img.setup_ui() txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) @@ -865,28 +888,43 @@ def create_ui(): with FormRow(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") - - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") - with gr.Column(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - with FormGroup(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") - - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") - with FormRow(elem_id="img2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") - - with FormGroup(elem_id="img2img_script_container"): - custom_inputs = modules.scripts.scripts_img2img.setup_ui() + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="img2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") + + if 
opts.dimensions_and_batch_together: + with gr.Column(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + elif category == "cfg": + with FormGroup(): + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') + + elif category == "checkboxes": + with FormRow(elem_id="img2img_checkboxes"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") + + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + elif category == "scripts": + with FormGroup(elem_id="img2img_script_container"): + custom_inputs = modules.scripts.scripts_img2img.setup_ui() img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) -- cgit v1.2.3 From fda1ed184381fdf8aa81be4f64e77787f3fac1b2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 12:01:32 +0300 Subject: some minor improvements for dark mode UI --- style.css | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/style.css b/style.css index 86a265f6..7296ce91 100644 --- a/style.css +++ b/style.css @@ -208,20 +208,20 @@ button{ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block span{ position: absolute; - top: -0.5em; + top: -0.7em; line-height: 1.2em; padding: 0; margin: 0 0.5em; background-color: white; - box-shadow: 0 0 5px 5px white; + box-shadow: 6px 0 6px 0px white, -6px 0 6px 0px white; z-index: 300; } .dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{ background-color: rgb(31, 41, 55); - box-shadow: 0 0 5px 5px rgb(31, 41, 55); + box-shadow: 6px 0 6px 0px rgb(31, 41, 55), -6px 0 6px 0px rgb(31, 41, 55); } #txt2img_column_batch, #img2img_column_batch{ -- cgit v1.2.3 From 9a3b0ee960b0c61c4f60e3081ae6f2098533d393 Mon Sep 17 00:00:00 2001 From: hithereai <121192995+hithereai@users.noreply.github.com> Date: Tue, 3 Jan 2023 11:22:06 +0200 Subject: update req.txt The old 'opencv-python' package is very limiting in terms of optical flow - so I propose a package change to 'opencv-contrib-python', which has more cv2.optflow methods. These are needed for optical flow trickery in auto1111 and its extensions, and it cannot be installed by an extension as only a single package of opencv needs to be installed for optical flow to work properly. Change of the main one is Inevitable. 
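For context on the swap below: the optical-flow extras live in OpenCV's contrib modules, so the base opencv-python wheel exposes no cv2.optflow at all. A hedged sketch of the kind of call that becomes available once opencv-contrib-python is installed — the frame filenames are placeholders, and DeepFlow is just one of the contrib-only estimators:

import cv2

# cv2.optflow exists only in opencv-contrib-python builds
prev = cv2.cvtColor(cv2.imread("frame_0000.png"), cv2.COLOR_BGR2GRAY)
curr = cv2.cvtColor(cv2.imread("frame_0001.png"), cv2.COLOR_BGR2GRAY)

flow_estimator = cv2.optflow.createOptFlow_DeepFlow()
flow = flow_estimator.calc(prev, curr, None)  # float32 flow field of shape (H, W, 2)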
--- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e2c3876b..4f09385f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ gradio==3.15.0 invisible-watermark numpy omegaconf -opencv-python +opencv-contrib-python requests piexif Pillow -- cgit v1.2.3 From c0ee1488702d5a6ae35fbf7e0422f9f685394920 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 14:18:48 +0300 Subject: add support for running with gradio 3.9 installed --- modules/generation_parameters_copypaste.py | 4 ++-- modules/ui_tempdir.py | 23 +++++++++++++++++++++-- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index ec60319a..d94f11a3 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -7,7 +7,7 @@ from pathlib import Path import gradio as gr from modules.shared import script_path -from modules import shared +from modules import shared, ui_tempdir import tempfile from PIL import Image @@ -39,7 +39,7 @@ def quote(text): def image_from_url_text(filedata): if type(filedata) == dict and filedata["is_file"]: filename = filedata["name"] - is_in_right_dir = any([filename in fileset for fileset in shared.demo.temp_file_sets]) + is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename) assert is_in_right_dir, 'trying to open image file outside of allowed directories' return Image.open(filename) diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py index 363d449d..21945235 100644 --- a/modules/ui_tempdir.py +++ b/modules/ui_tempdir.py @@ -1,6 +1,7 @@ import os import tempfile from collections import namedtuple +from pathlib import Path import gradio as gr @@ -12,10 +13,28 @@ from modules import shared Savedfile = namedtuple("Savedfile", ["name"]) +def register_tmp_file(gradio, filename): + if hasattr(gradio, 'temp_file_sets'): # gradio 3.15 + gradio.temp_file_sets[0] = gradio.temp_file_sets[0] | {os.path.abspath(filename)} + + if hasattr(gradio, 'temp_dirs'): # gradio 3.9 + gradio.temp_dirs = gradio.temp_dirs | {os.path.abspath(os.path.dirname(filename))} + + +def check_tmp_file(gradio, filename): + if hasattr(gradio, 'temp_file_sets'): + return any([filename in fileset for fileset in gradio.temp_file_sets]) + + if hasattr(gradio, 'temp_dirs'): + return any(Path(temp_dir).resolve() in Path(filename).resolve().parents for temp_dir in gradio.temp_dirs) + + return False + + def save_pil_to_file(pil_image, dir=None): already_saved_as = getattr(pil_image, 'already_saved_as', None) if already_saved_as and os.path.isfile(already_saved_as): - shared.demo.temp_file_sets[0] = shared.demo.temp_file_sets[0] | {os.path.abspath(already_saved_as)} + register_tmp_file(shared.demo, already_saved_as) file_obj = Savedfile(already_saved_as) return file_obj @@ -45,7 +64,7 @@ def on_tmpdir_changed(): os.makedirs(shared.opts.temp_dir, exist_ok=True) - shared.demo.temp_file_sets[0] = shared.demo.temp_file_sets[0] | {os.path.abspath(shared.opts.temp_dir)} + register_tmp_file(shared.demo, os.path.join(shared.opts.temp_dir, "x")) def cleanup_tmpdr(): -- cgit v1.2.3 From bddebe09edeb6a18f2c06986d5658a7be3a563ea Mon Sep 17 00:00:00 2001 From: Shondoit Date: Tue, 3 Jan 2023 10:26:37 +0100 Subject: Save Optimizer next to TI embedding Also add check to load only .PT and .BIN files as embeddings. 
(since we add .optim files in the same directory) --- modules/shared.py | 2 +- modules/textual_inversion/textual_inversion.py | 40 ++++++++++++++++++++------ 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index 23657a93..c541d18c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -355,7 +355,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), - "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."), + "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index fd253477..16176e90 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -28,6 +28,7 @@ class Embedding: self.cached_checksum = None self.sd_checkpoint = None self.sd_checkpoint_name = None + self.optimizer_state_dict = None def save(self, filename): embedding_data = { @@ -41,6 +42,13 @@ class Embedding: torch.save(embedding_data, filename) + if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None: + optimizer_saved_dict = { + 'hash': self.checksum(), + 'optimizer_state_dict': self.optimizer_state_dict, + } + torch.save(optimizer_saved_dict, filename + '.optim') + def checksum(self): if self.cached_checksum is not None: return self.cached_checksum @@ -95,9 +103,10 @@ class EmbeddingDatabase: self.expected_shape = self.get_expected_shape() def process_file(path, filename): - name = os.path.splitext(filename)[0] + name, ext = os.path.splitext(filename) + ext = ext.upper() - if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']: + if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: embed_image = Image.open(path) if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: data = embedding_from_b64(embed_image.text['sd-ti-embedding']) @@ -105,8 +114,10 @@ class EmbeddingDatabase: else: data = extract_image_data_embed(embed_image) name = data.get('name', name) - else: + elif ext in ['.BIN', '.PT']: data = torch.load(path, map_location="cpu") + else: + return # textual inversion embeddings if 'string_to_param' in data: @@ -300,6 +311,20 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ embedding.vec.requires_grad = True optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0) + if shared.opts.save_optimizer_state: + optimizer_state_dict = None + if os.path.exists(filename + '.optim'): + optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu') + if embedding.checksum() == optimizer_saved_dict.get('hash', None): + 
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) + + if optimizer_state_dict is not None: + optimizer.load_state_dict(optimizer_state_dict) + print("Loaded existing optimizer from checkpoint") + else: + print("No saved optimizer exists in checkpoint") + + scaler = torch.cuda.amp.GradScaler() batch_size = ds.batch_size @@ -366,9 +391,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ # Before saving, change name to match current checkpoint. embedding_name_every = f'{embedding_name}-{steps_done}' last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt') - #if shared.opts.save_optimizer_state: - #embedding.optimizer_state_dict = optimizer.state_dict() - save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) + save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True) embedding_yet_to_be_embedded = True write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, { @@ -458,7 +481,7 @@ Last saved image: {html.escape(last_saved_image)}
""" filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') - save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) + save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True) except Exception: print(traceback.format_exc(), file=sys.stderr) pass @@ -470,7 +493,7 @@ Last saved image: {html.escape(last_saved_image)}
return embedding, filename -def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True): +def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True): old_embedding_name = embedding.name old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None @@ -481,6 +504,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache if remove_cached_checksum: embedding.cached_checksum = None embedding.name = embedding_name + embedding.optimizer_state_dict = optimizer.state_dict() embedding.save(filename) except: embedding.sd_checkpoint = old_sd_checkpoint -- cgit v1.2.3 From e9fb9bb0c25f59109a816fc53c385bed58965c24 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 17:40:20 +0300 Subject: fix hires fix not working in API when user does not specify upscaler --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 4654570c..a172af0b 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -685,7 +685,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_default_mode + latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") if self.enable_hr and latent_scale_mode is None: assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}" -- cgit v1.2.3 From aaa4c2aacbb6523077334093c81bd475d757f7a1 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Tue, 3 Jan 2023 09:45:16 -0500 Subject: add api logging --- modules/api/api.py | 24 +++++++++++++++++++++++- modules/shared.py | 1 + 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 9c670f00..53135470 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -1,11 +1,12 @@ import base64 import io import time +import datetime import uvicorn from threading import Lock from io import BytesIO from gradio.processing_utils import decode_base64_to_file -from fastapi import APIRouter, Depends, FastAPI, HTTPException +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest @@ -67,6 +68,26 @@ def encode_pil_to_base64(image): bytes_data = output_bytes.getvalue() return base64.b64encode(bytes_data) +def init_api_middleware(app: FastAPI): + @app.middleware("http") + async def log_and_time(req: Request, call_next): + ts = time.time() + res: Response = await call_next(req) + duration = str(round(time.time() - ts, 4)) + res.headers["X-Process-Time"] = duration + if shared.cmd_opts.api_log: + print('API {t} {code} {prot}/{ver} {method} {p} {cli} {duration}'.format( + t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), + code = res.status_code, + ver = req.scope.get('http_version', 
'0.0'), + cli = req.scope.get('client', ('0:0.0.0', 0))[0], + prot = req.scope.get('scheme', 'err'), + method = req.scope.get('method', 'err'), + p = req.scope.get('path', 'err'), + duration = duration, + )) + return res + class Api: def __init__(self, app: FastAPI, queue_lock: Lock): @@ -78,6 +99,7 @@ class Api: self.router = APIRouter() self.app = app + init_api_middleware(self.app) self.queue_lock = queue_lock self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) diff --git a/modules/shared.py b/modules/shared.py index 23657a93..2a03d716 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -82,6 +82,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) +parser.add_argument("--api-log", action='store_true', help="use api-log=True to enable logging of all API requests") parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the API instead of the webui") parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) -- cgit v1.2.3 From 1d9dc48efda2e8da6d13fc62e65500198a9b041c Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Tue, 3 Jan 2023 10:21:51 -0500 Subject: init job and add info to model merge --- modules/extras.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index 5e270250..7e222313 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -242,6 +242,9 @@ def run_pnginfo(image): def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format): + shared.state.begin() + shared.state.job = 'model-merge' + def weighted_sum(theta0, theta1, alpha): return ((1 - alpha) * theta0) + (alpha * theta1) @@ -263,8 +266,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam theta_func1, theta_func2 = theta_funcs[interp_method] if theta_func1 and not tertiary_model_info: + shared.state.textinfo = "Failed: Interpolation method requires a tertiary model." + shared.state.end() return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] + shared.state.textinfo = f"Loading {secondary_model_info.filename}..." print(f"Loading {secondary_model_info.filename}...") theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') @@ -281,6 +287,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam theta_1[key] = torch.zeros_like(theta_1[key]) del theta_2 + shared.state.textinfo = f"Loading {primary_model_info.filename}..." 
print(f"Loading {primary_model_info.filename}...") theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') @@ -291,6 +298,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam a = theta_0[key] b = theta_1[key] + shared.state.textinfo = f'Merging layer {key}' # this enables merging an inpainting model (A) with another one (B); # where normal model would have 4 channels, for latenst space, inpainting model would # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 @@ -303,8 +311,6 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) result_is_inpainting_model = True else: - assert a.shape == b.shape, f'Incompatible shapes for layer {key}: A is {a.shape}, and B is {b.shape}' - theta_0[key] = theta_func2(a, b, multiplier) if save_as_half: @@ -332,6 +338,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam output_modelname = os.path.join(ckpt_dir, filename) + shared.state.textinfo = f"Saving to {output_modelname}..." print(f"Saving to {output_modelname}...") _, extension = os.path.splitext(output_modelname) @@ -343,4 +350,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam sd_models.list_models() print("Checkpoint saved.") + shared.state.textinfo = "Checkpoint saved to " + output_modelname + shared.state.end() + return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] -- cgit v1.2.3 From 192ddc04d6de0d780f73aa5fbaa8c66cd4642e1c Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Tue, 3 Jan 2023 10:34:51 -0500 Subject: add job info to modules --- modules/extras.py | 17 +++++++++++++---- modules/hypernetworks/hypernetwork.py | 1 + modules/textual_inversion/preprocess.py | 1 + modules/textual_inversion/textual_inversion.py | 1 + 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index 7e222313..d665440a 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -58,6 +58,9 @@ cached_images: LruCache = LruCache(max_size=5) def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): devices.torch_gc() + shared.state.begin() + shared.state.job = 'extras' + imageArr = [] # Also keep track of original file names imageNameArr = [] @@ -94,6 +97,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ # Extra operation definitions def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + shared.state.job = 'extras-gfpgan' restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) res = Image.fromarray(restored_img) @@ -104,6 +108,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ return (res, info) def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + shared.state.job = 'extras-codeformer' restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) res = Image.fromarray(restored_img) @@ -114,6 +119,7 @@ def 
run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ return (res, info) def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): + shared.state.job = 'extras-upscale' upscaler = shared.sd_upscalers[scaler_index] res = upscaler.scaler.upscale(image, resize, upscaler.data_path) if mode == 1 and crop: @@ -180,6 +186,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ for image, image_name in zip(imageArr, imageNameArr): if image is None: return outputs, "Please select an input image.", '' + + shared.state.textinfo = f'Processing image {image_name}' + existing_pnginfo = image.info or {} image = image.convert("RGB") @@ -193,6 +202,10 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ else: basename = '' + if opts.enable_pnginfo: # append info before save + image.info = existing_pnginfo + image.info["extras"] = info + if save_output: # Add upscaler name as a suffix. suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" @@ -203,10 +216,6 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) - if opts.enable_pnginfo: - image.info = existing_pnginfo - image.info["extras"] = info - if extras_mode != 2 or show_extras_results : outputs.append(image) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 109e8078..450fecac 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -417,6 +417,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, shared.loaded_hypernetwork = Hypernetwork() shared.loaded_hypernetwork.load(path) + shared.state.job = "train-hypernetwork" shared.state.textinfo = "Initializing hypernetwork training..." shared.state.job_count = steps diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 56b9b2eb..feb876c6 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -124,6 +124,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre files = listfiles(src) + shared.state.job = "preprocess" shared.state.textinfo = "Preprocessing..." shared.state.job_count = len(files) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index fd253477..2c1251d6 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -245,6 +245,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ create_image_every = create_image_every or 0 validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") + shared.state.job = "train-embedding" shared.state.textinfo = "Initializing textual inversion training..." 
shared.state.job_count = steps -- cgit v1.2.3 From 2d5a5076bb2a0c05cc27d75a1bcadab7f32a46d0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 18:38:21 +0300 Subject: Make it so that upscalers are not repeated when restarting UI. --- modules/modelloader.py | 20 ++++++++++++++++++++ webui.py | 14 +++++++------- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/modules/modelloader.py b/modules/modelloader.py index e647f6fa..6a1a7ac8 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -123,6 +123,23 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None): pass +builtin_upscaler_classes = [] +forbidden_upscaler_classes = set() + + +def list_builtin_upscalers(): + load_upscalers() + + builtin_upscaler_classes.clear() + builtin_upscaler_classes.extend(Upscaler.__subclasses__()) + + +def forbid_loaded_nonbuiltin_upscalers(): + for cls in Upscaler.__subclasses__(): + if cls not in builtin_upscaler_classes: + forbidden_upscaler_classes.add(cls) + + def load_upscalers(): # We can only do this 'magic' method to dynamically load upscalers if they are referenced, # so we'll try to import any _model.py files before looking in __subclasses__ @@ -139,6 +156,9 @@ def load_upscalers(): datas = [] commandline_options = vars(shared.cmd_opts) for cls in Upscaler.__subclasses__(): + if cls in forbidden_upscaler_classes: + continue + name = cls.__name__ cmd_name = f"{name.lower().replace('upscaler', '')}_models_path" scaler = cls(commandline_options.get(cmd_name, None)) diff --git a/webui.py b/webui.py index 3aee8792..c7d55a97 100644 --- a/webui.py +++ b/webui.py @@ -1,4 +1,5 @@ import os +import sys import threading import time import importlib @@ -55,8 +56,8 @@ def initialize(): gfpgan.setup_model(cmd_opts.gfpgan_models_path) shared.face_restorers.append(modules.face_restoration.FaceRestoration()) + modelloader.list_builtin_upscalers() modules.scripts.load_scripts() - modelloader.load_upscalers() modules.sd_vae.refresh_vae_list() @@ -169,23 +170,22 @@ def webui(): modules.script_callbacks.app_started_callback(shared.demo, app) wait_on_server(shared.demo) + print('Restarting UI...') sd_samplers.set_samplers() - print('Reloading extensions') extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) - print('Reloading custom scripts') + modelloader.forbid_loaded_nonbuiltin_upscalers() modules.scripts.reload_scripts() modelloader.load_upscalers() - print('Reloading modules: modules.ui') - importlib.reload(modules.ui) - print('Refreshing Model List') + for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: + importlib.reload(module) + modules.sd_models.list_models() - print('Restarting Gradio') if __name__ == "__main__": -- cgit v1.2.3 From 8f96f9289981a66741ba770d14f3d27ce335a0fb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 18:39:14 +0300 Subject: call script callbacks for reloaded model after loading embeddings --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index bff8d6c9..b98b05fc 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -324,12 +324,12 @@ def load_model(checkpoint_info=None): sd_model.eval() shared.sd_model = sd_model + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model + 
script_callbacks.model_loaded_callback(sd_model) print("Model loaded.") - sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload = True) # Reload embeddings after model load as they may or may not fit the model - return sd_model -- cgit v1.2.3 From cec209981ee988536c2521297baf9bc1b256005f Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Tue, 3 Jan 2023 10:58:52 -0500 Subject: log only sdapi --- modules/api/api.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 53135470..78751c57 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -68,22 +68,23 @@ def encode_pil_to_base64(image): bytes_data = output_bytes.getvalue() return base64.b64encode(bytes_data) -def init_api_middleware(app: FastAPI): +def api_middleware(app: FastAPI): @app.middleware("http") async def log_and_time(req: Request, call_next): ts = time.time() res: Response = await call_next(req) duration = str(round(time.time() - ts, 4)) res.headers["X-Process-Time"] = duration - if shared.cmd_opts.api_log: - print('API {t} {code} {prot}/{ver} {method} {p} {cli} {duration}'.format( + endpoint = req.scope.get('path', 'err') + if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'): + print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format( t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), code = res.status_code, ver = req.scope.get('http_version', '0.0'), cli = req.scope.get('client', ('0:0.0.0', 0))[0], prot = req.scope.get('scheme', 'err'), method = req.scope.get('method', 'err'), - p = req.scope.get('path', 'err'), + endpoint = endpoint, duration = duration, )) return res -- cgit v1.2.3 From d8d206c1685d1e7027d4af82ed18d106f41d1cc4 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Tue, 3 Jan 2023 11:01:04 -0500 Subject: add state to interrogate --- modules/interrogate.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/interrogate.py b/modules/interrogate.py index 6f761c5a..738d8ff7 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -136,7 +136,8 @@ class InterrogateModels: def interrogate(self, pil_image): res = "" - + shared.state.begin() + shared.state.job = 'interrogate' try: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: @@ -177,5 +178,6 @@ class InterrogateModels: res += "" self.unload() + shared.state.end() return res -- cgit v1.2.3 From 82cfc227d735c140447d5b8dca29a71ee9bde127 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 20:23:17 +0300 Subject: added licenses screen to settings added footer removed unused inpainting code --- README.md | 2 + html/footer.html | 9 + html/licenses.html | 392 ++++++++++++++++++++++++++++++++++++++++ modules/sd_hijack_inpainting.py | 232 ------------------------ modules/ui.py | 15 +- style.css | 11 ++ 6 files changed, 427 insertions(+), 234 deletions(-) create mode 100644 html/footer.html create mode 100644 html/licenses.html diff --git a/README.md b/README.md index 556000fb..88250a6b 100644 --- a/README.md +++ b/README.md @@ -127,6 +127,8 @@ Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki). ## Credits +Licenses for borrowed code can be found in `Settings -> Licenses` screen, and also in `html/licenses.html` file. 
+
 - Stable Diffusion - https://github.com/CompVis/stable-diffusion, https://github.com/CompVis/taming-transformers
 - k-diffusion - https://github.com/crowsonkb/k-diffusion.git
 - GFPGAN - https://github.com/TencentARC/GFPGAN.git
diff --git a/html/footer.html b/html/footer.html
new file mode 100644
index 00000000..a8f2adf7
--- /dev/null
+++ b/html/footer.html
@@ -0,0 +1,9 @@
+<div>
+        <a href="/docs">API</a>
+         • 
+        <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui">Github</a>
+         • 
+        <a href="https://gradio.app">Gradio</a>
+         • 
+        <a href="/" onclick="javascript:gradioApp().getElementById('settings_restart_gradio').click(); return false">Reload UI</a>
+</div>
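
Note: the footer fragment above is wired into the interface by the ui.py hunk later in this patch, which reads the file at startup and injects it as a Gradio HTML component. A minimal sketch of that pattern, assuming gradio 3.x (the surrounding Blocks context here is illustrative, not part of the patch):

    import os

    import gradio as gr

    with gr.Blocks(analytics_enabled=False) as demo:
        # Inject the static fragment only if it exists, as the patch does.
        if os.path.exists("html/footer.html"):
            with open("html/footer.html", encoding="utf8") as file:
                gr.HTML(file.read(), elem_id="footer")

The Reload UI link works because its onclick handler clicks the button that ui.py registers under elem_id "settings_restart_gradio" in this same commit.
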
diff --git a/html/licenses.html b/html/licenses.html
new file mode 100644
index 00000000..9eeaa072
--- /dev/null
+++ b/html/licenses.html
@@ -0,0 +1,392 @@
+
+

CodeFormer

+Parts of CodeFormer code had to be copied to be compatible with GFPGAN. +
+S-Lab License 1.0
+
+Copyright 2022 S-Lab
+
+Redistribution and use for non-commercial purpose in source and
+binary forms, with or without modification, are permitted provided
+that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in
+   the documentation and/or other materials provided with the
+   distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived
+   from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+In the event that redistribution and/or use for commercial purpose in
+source or binary forms, with or without modification is required,
+please contact the contributor(s) of the work.
+
+ + +

ESRGAN

+Code for architecture and reading models copied. +
+MIT License
+
+Copyright (c) 2021 victorca25
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+ +

Real-ESRGAN

+Some code is copied to support ESRGAN models. +
+BSD 3-Clause License
+
+Copyright (c) 2021, Xintao Wang
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived from
+   this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+ +

InvokeAI

+Some code for compatibility with OSX is taken from lstein's repository. +
+MIT License
+
+Copyright (c) 2022 InvokeAI Team
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+ +

LDSR

+Code added by contributors, most likely copied from this repository.
+
+MIT License
+
+Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+ +

CLIP Interrogator

+Some small amounts of code borrowed and reworked. +
+MIT License
+
+Copyright (c) 2022 pharmapsychotic
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+ +

SwinIR

+Code added by contributors, most likely copied from this repository.
+
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [2021] [SwinIR Authors]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+
+ diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index 06b75772..3c214a35 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -12,191 +12,6 @@ from ldm.models.diffusion.ddpm import LatentDiffusion from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.ddim import DDIMSampler, noise_like -# ================================================================================================= -# Monkey patch DDIMSampler methods from RunwayML repo directly. -# Adapted from: -# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py -# ================================================================================================= -@torch.no_grad() -def sample_ddim(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... - **kwargs - ): - if conditioning is not None: - if isinstance(conditioning, dict): - ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): - ctmp = ctmp[0] - cbs = ctmp.shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - print(f'Data shape for DDIM sampling is {size}, eta {eta}') - - samples, intermediates = self.ddim_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - ) - return samples, intermediates - -@torch.no_grad() -def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None): - b, *_, device = *x.shape, x.device - - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: - e_t = self.model.apply_model(x, t, c) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - if isinstance(c, dict): - assert isinstance(unconditional_conditioning, dict) - c_in = dict() - for k in c: - if isinstance(c[k], list): - c_in[k] = [ - torch.cat([unconditional_conditioning[k][i], c[k][i]]) - for i in range(len(c[k])) - ] - else: - c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) - else: - c_in = torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - - if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = 
score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) - - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - -# ================================================================================================= -# Monkey patch PLMSSampler methods. -# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes. -# Adapted from: -# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py -# ================================================================================================= -@torch.no_grad() -def sample_plms(self, - S, - batch_size, - shape, - conditioning=None, - callback=None, - normals_sequence=None, - img_callback=None, - quantize_x0=False, - eta=0., - mask=None, - x0=None, - temperature=1., - noise_dropout=0., - score_corrector=None, - corrector_kwargs=None, - verbose=True, - x_T=None, - log_every_t=100, - unconditional_guidance_scale=1., - unconditional_conditioning=None, - # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
- **kwargs - ): - if conditioning is not None: - if isinstance(conditioning, dict): - ctmp = conditioning[list(conditioning.keys())[0]] - while isinstance(ctmp, list): - ctmp = ctmp[0] - cbs = ctmp.shape[0] - if cbs != batch_size: - print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") - else: - if conditioning.shape[0] != batch_size: - print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") - - self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) - # sampling - C, H, W = shape - size = (batch_size, C, H, W) - # print(f'Data shape for PLMS sampling is {size}') # remove unnecessary message - - samples, intermediates = self.plms_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, - mask=mask, x0=x0, - ddim_use_original_steps=False, - noise_dropout=noise_dropout, - temperature=temperature, - score_corrector=score_corrector, - corrector_kwargs=corrector_kwargs, - x_T=x_T, - log_every_t=log_every_t, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning, - ) - return samples, intermediates - @torch.no_grad() def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, @@ -280,44 +95,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F return x_prev, pred_x0, e_t -# ================================================================================================= -# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config. -# Adapted from: -# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py -# ================================================================================================= - -@torch.no_grad() -def get_unconditional_conditioning(self, batch_size, null_label=None): - if null_label is not None: - xc = null_label - if isinstance(xc, ListConfig): - xc = list(xc) - if isinstance(xc, dict) or isinstance(xc, list): - c = self.get_learned_conditioning(xc) - else: - if hasattr(xc, "to"): - xc = xc.to(self.device) - c = self.get_learned_conditioning(xc) - else: - # todo: get null label from cond_stage_model - raise NotImplementedError() - c = repeat(c, "1 ... 
-> b ...", b=batch_size).to(self.device) - return c - - -class LatentInpaintDiffusion(LatentDiffusion): - def __init__( - self, - concat_keys=("mask", "masked_image"), - masked_image_key="masked_image", - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.masked_image_key = masked_image_key - assert self.masked_image_key in concat_keys - self.concat_keys = concat_keys - def should_hijack_inpainting(checkpoint_info): ckpt_basename = os.path.basename(checkpoint_info.filename).lower() @@ -326,15 +103,6 @@ def should_hijack_inpainting(checkpoint_info): def do_inpainting_hijack(): - # most of this stuff seems to no longer be needed because it is already included into SD2.0 # p_sample_plms is needed because PLMS can't work with dicts as conditionings - # this file should be cleaned up later if everything turns out to work fine - - # ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning - # ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion - - # ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim - # ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms - # ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms diff --git a/modules/ui.py b/modules/ui.py index f2e7c0d6..d941cb5f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1529,8 +1529,10 @@ def create_ui(): with gr.Blocks(analytics_enabled=False) as settings_interface: with gr.Row(): - settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") - restart_gradio = gr.Button(value='Restart UI', variant='primary', elem_id="settings_restart_gradio") + with gr.Column(scale=6): + settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") + with gr.Column(): + restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") result = gr.HTML(elem_id="settings_result") @@ -1574,6 +1576,11 @@ def create_ui(): download_localization = gr.Button(value='Download localization template', elem_id="download_localization") reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") + if os.path.exists("html/licenses.html"): + with open("html/licenses.html", encoding="utf8") as file: + with gr.TabItem("Licenses"): + gr.HTML(file.read(), elem_id="licenses") + gr.Button(value="Show all pages", elem_id="settings_show_all_pages") request_notifications.click( @@ -1659,6 +1666,10 @@ def create_ui(): if os.path.exists(os.path.join(script_path, "notification.mp3")): audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) + if os.path.exists("html/footer.html"): + with open("html/footer.html", encoding="utf8") as file: + gr.HTML(file.read(), elem_id="footer") + text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) settings_submit.click( fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), diff --git a/style.css b/style.css index 7296ce91..2116ec3c 100644 --- a/style.css +++ b/style.css @@ -616,6 +616,17 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h padding-bottom: 0.5em; } +footer { + display: none !important; +} + +#footer{ + text-align: center; +} + +#footer div{ + display: inline-block; +} /* The following 
handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. -- cgit v1.2.3 From 7c89f3718f9f078113833a88a86f02d3205855b4 Mon Sep 17 00:00:00 2001 From: MMaker Date: Tue, 3 Jan 2023 12:46:48 -0500 Subject: Add image paste fallback Fixes Firefox pasting support (and possibly other browsers) --- javascript/dragdrop.js | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js index 3ed1cb3c..fe008924 100644 --- a/javascript/dragdrop.js +++ b/javascript/dragdrop.js @@ -9,11 +9,19 @@ function dropReplaceImage( imgWrap, files ) { return; } + const tmpFile = files[0]; + imgWrap.querySelector('.modify-upload button + button, .touch-none + div button + button')?.click(); const callback = () => { const fileInput = imgWrap.querySelector('input[type="file"]'); if ( fileInput ) { - fileInput.files = files; + if ( files.length === 0 ) { + files = new DataTransfer(); + files.items.add(tmpFile); + fileInput.files = files.files; + } else { + fileInput.files = files; + } fileInput.dispatchEvent(new Event('change')); } }; -- cgit v1.2.3 From 3e22e294135ed0327ce9d9738655ff03c53df3c0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 21:49:24 +0300 Subject: fix broken send to extras button --- modules/generation_parameters_copypaste.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index d94f11a3..4baf4d9a 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -37,7 +37,10 @@ def quote(text): def image_from_url_text(filedata): - if type(filedata) == dict and filedata["is_file"]: + if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False): + filedata = filedata[0] + + if type(filedata) == dict and filedata.get("is_file", False): filename = filedata["name"] is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename) assert is_in_right_dir, 'trying to open image file outside of allowed directories' -- cgit v1.2.3 From 917b5bd8d0cd47c9dc241c1852ccd440a8c61668 Mon Sep 17 00:00:00 2001 From: Max Weber Date: Tue, 3 Jan 2023 18:19:56 -0700 Subject: ui: save dropdown sampling method to the ui-config --- modules/ui.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/ui.py b/modules/ui.py index d941cb5f..bfc93634 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -635,6 +635,7 @@ def create_sampler_and_steps_selection(choices, tabname): if opts.samplers_in_dropdown: with FormRow(elem_id=f"sampler_selection_{tabname}"): sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + sampler_index.save_to_config = True steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20) else: with FormGroup(elem_id=f"sampler_selection_{tabname}"): -- cgit v1.2.3 From 4fc81542077af73610279ad7b6b26e38718a0f81 Mon Sep 17 00:00:00 2001 From: Gerschel Date: Tue, 3 Jan 2023 23:25:34 -0800 Subject: better targetting, class tabs was autoassigned I moved a preset manager into quicksettings, this function was targeting my component instead of the tabs. 
This is because class tabs is autoassigned, while element id #tabs is not, this allows a tabbed component to live in the quicksettings. --- script.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/script.js b/script.js index 9748ec90..0e117d06 100644 --- a/script.js +++ b/script.js @@ -4,7 +4,7 @@ function gradioApp() { } function get_uiCurrentTab() { - return gradioApp().querySelector('.tabs button:not(.border-transparent)') + return gradioApp().querySelector('#tabs button:not(.border-transparent)') } function get_uiCurrentTabContent() { -- cgit v1.2.3 From e5b7ee910e7bb88f08e8876b5732cb034c6fe529 Mon Sep 17 00:00:00 2001 From: MMaker Date: Wed, 4 Jan 2023 04:22:01 -0500 Subject: fix: Save full res of intermediate step --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index a172af0b..93e75ba6 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -705,7 +705,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): return if not isinstance(image, Image.Image): - image = sd_samplers.sample_to_image(image, index) + image = sd_samplers.sample_to_image(image, index, approximation=0) images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix") -- cgit v1.2.3 From 02d7abf5141431b9a3a8a189bb3136c71abd5e79 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 12:35:07 +0300 Subject: helpful error message when trying to load 2.0 without config failing to load model weights from settings won't break generation for currently loaded model anymore --- modules/errors.py | 25 +++++++++++++++++++++++-- modules/sd_models.py | 26 ++++++++++++++++++-------- modules/shared.py | 9 +++++++-- webui.py | 12 ++++++++++-- 4 files changed, 58 insertions(+), 14 deletions(-) diff --git a/modules/errors.py b/modules/errors.py index 372dc51a..a668c014 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -2,9 +2,30 @@ import sys import traceback +def print_error_explanation(message): + lines = message.strip().split("\n") + max_len = max([len(x) for x in lines]) + + print('=' * max_len, file=sys.stderr) + for line in lines: + print(line, file=sys.stderr) + print('=' * max_len, file=sys.stderr) + + +def display(e: Exception, task): + print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + message = str(e) + if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message: + print_error_explanation(""" +The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file. +See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this. 
+ """) + + def run(code, task): try: code() except Exception as e: - print(f"{task}: {type(e).__name__}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + display(task, e) diff --git a/modules/sd_models.py b/modules/sd_models.py index b98b05fc..6846b74a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -278,6 +278,7 @@ def enable_midas_autodownload(): midas.api.load_model = load_model_wrapper + def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -312,6 +313,7 @@ def load_model(checkpoint_info=None): sd_config.model.params.unet_config.params.use_fp16 = False sd_model = instantiate_from_config(sd_config.model) + load_model_weights(sd_model, checkpoint_info) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: @@ -336,10 +338,12 @@ def load_model(checkpoint_info=None): def reload_model_weights(sd_model=None, info=None): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() - + if not sd_model: sd_model = shared.sd_model + current_checkpoint_info = sd_model.sd_checkpoint_info + if sd_model.sd_model_checkpoint == checkpoint_info.filename: return @@ -356,13 +360,19 @@ def reload_model_weights(sd_model=None, info=None): sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, checkpoint_info) - - sd_hijack.model_hijack.hijack(sd_model) - script_callbacks.model_loaded_callback(sd_model) - - if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: - sd_model.to(devices.device) + try: + load_model_weights(sd_model, checkpoint_info) + except Exception as e: + print("Failed to load checkpoint, restoring previous") + load_model_weights(sd_model, current_checkpoint_info) + raise + finally: + sd_hijack.model_hijack.hijack(sd_model) + script_callbacks.model_loaded_callback(sd_model) + + if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + sd_model.to(devices.device) print("Weights loaded.") + return sd_model diff --git a/modules/shared.py b/modules/shared.py index 23657a93..7588c47b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.interrogate import modules.memmon import modules.styles import modules.devices as devices -from modules import localization, sd_vae, extensions, script_loading +from modules import localization, sd_vae, extensions, script_loading, errors from modules.paths import models_path, script_path, sd_path @@ -494,7 +494,12 @@ class Options: return False if self.data_labels[key].onchange is not None: - self.data_labels[key].onchange() + try: + self.data_labels[key].onchange() + except Exception as e: + errors.display(e, f"changing setting {key} to {value}") + setattr(self, key, oldval) + return False return True diff --git a/webui.py b/webui.py index c7d55a97..13375e71 100644 --- a/webui.py +++ b/webui.py @@ -9,7 +9,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware -from modules import import_hook +from modules import import_hook, errors from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path @@ -61,7 +61,15 @@ def initialize(): modelloader.load_upscalers() modules.sd_vae.refresh_vae_list() - modules.sd_models.load_model() + + try: + modules.sd_models.load_model() + except Exception as e: + errors.display(e, "loading stable diffusion model") + print("", file=sys.stderr) + print("Stable diffusion model failed 
to load, exiting", file=sys.stderr) + exit(1) + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) -- cgit v1.2.3 From 8d8a05a3bbb50fdfeab51679a919d2487bd97976 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 12:47:42 +0300 Subject: find configs for models at runtime rather than when starting --- modules/sd_hijack_inpainting.py | 5 ++++- modules/sd_models.py | 31 ++++++++++++++++++------------- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index 3c214a35..31d2c898 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -97,8 +97,11 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F def should_hijack_inpainting(checkpoint_info): + from modules import sd_models + ckpt_basename = os.path.basename(checkpoint_info.filename).lower() - cfg_basename = os.path.basename(checkpoint_info.config).lower() + cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower() + return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename diff --git a/modules/sd_models.py b/modules/sd_models.py index 6846b74a..6dca4ddf 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) -CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config']) +CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) checkpoints_list = {} checkpoints_loaded = collections.OrderedDict() @@ -48,6 +48,14 @@ def checkpoint_tiles(): return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key) +def find_checkpoint_config(info): + config = os.path.splitext(info.filename)[0] + ".yaml" + if os.path.exists(config): + return config + + return shared.cmd_opts.config + + def list_models(): checkpoints_list.clear() model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"]) @@ -73,7 +81,7 @@ def list_models(): if os.path.exists(cmd_ckpt): h = model_hash(cmd_ckpt) title, short_model_name = modeltitle(cmd_ckpt, h) - checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config) + checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) shared.opts.data['sd_model_checkpoint'] = title elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) @@ -81,12 +89,7 @@ def list_models(): h = model_hash(filename) title, short_model_name = modeltitle(filename, h) - basename, _ = os.path.splitext(filename) - config = basename + ".yaml" - if not os.path.exists(config): - config = shared.cmd_opts.config - - checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config) + checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name) def 
get_closet_checkpoint_match(searchString): @@ -282,9 +285,10 @@ def enable_midas_autodownload(): def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() + checkpoint_config = find_checkpoint_config(checkpoint_info) - if checkpoint_info.config != shared.cmd_opts.config: - print(f"Loading config from: {checkpoint_info.config}") + if checkpoint_config != shared.cmd_opts.config: + print(f"Loading config from: {checkpoint_config}") if shared.sd_model: sd_hijack.model_hijack.undo_hijack(shared.sd_model) @@ -292,7 +296,7 @@ def load_model(checkpoint_info=None): gc.collect() devices.torch_gc() - sd_config = OmegaConf.load(checkpoint_info.config) + sd_config = OmegaConf.load(checkpoint_config) if should_hijack_inpainting(checkpoint_info): # Hardcoded config for now... @@ -302,7 +306,7 @@ def load_model(checkpoint_info=None): sd_config.model.params.finetune_keys = None # Create a "fake" config with a different name so that we know to unload it when switching models. - checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) + checkpoint_info = checkpoint_info._replace(config=checkpoint_config.replace(".yaml", "-inpainting.yaml")) if not hasattr(sd_config.model.params, "use_ema"): sd_config.model.params.use_ema = False @@ -343,11 +347,12 @@ def reload_model_weights(sd_model=None, info=None): sd_model = shared.sd_model current_checkpoint_info = sd_model.sd_checkpoint_info + checkpoint_config = find_checkpoint_config(current_checkpoint_info) if sd_model.sd_model_checkpoint == checkpoint_info.filename: return - if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): + if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): del sd_model checkpoints_loaded.clear() load_model(checkpoint_info) -- cgit v1.2.3 From 96cf15bedecbed97ef9b70b8413d543a9aee5adf Mon Sep 17 00:00:00 2001 From: MMaker Date: Wed, 4 Jan 2023 05:12:06 -0500 Subject: Add new latent upscale modes --- modules/shared.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index 7588c47b..a10f69a9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -564,8 +564,11 @@ if os.path.exists(config_filename): latent_upscale_default_mode = "Latent" latent_upscale_modes = { - "Latent": "bilinear", - "Latent (nearest)": "nearest", + "Latent": {"mode": "bilinear", "antialias": False}, + "Latent (antialiased)": {"mode": "bilinear", "antialias": True}, + "Latent (bicubic)": {"mode": "bicubic", "antialias": False}, + "Latent (bicubic, antialiased)": {"mode": "bicubic", "antialias": True}, + "Latent (nearest)": {"mode": "nearest", "antialias": False}, } sd_upscalers = [] -- cgit v1.2.3 From 15fd0b8bc4734ea85bca1acfb12b51465ab9817d Mon Sep 17 00:00:00 2001 From: MMaker Date: Wed, 4 Jan 2023 05:12:54 -0500 Subject: Update processing.py --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index a172af0b..7c72b56a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -713,7 +713,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): for i in range(samples.shape[0]): save_intermediate(samples, i) - samples = 
torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode) + samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"]) # Avoid making the inpainting conditioning unless necessary as # this does need some extra compute to decode / encode the image again. -- cgit v1.2.3 From 4ec6470a1a2d9430b91266426f995e48f59564e1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 13:26:23 +0300 Subject: fix checkpoint list API --- modules/api/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 9c670f00..2b1f180c 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -18,7 +18,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image -from modules.sd_models import checkpoints_list +from modules.sd_models import checkpoints_list, find_checkpoint_config from modules.realesrgan_model import get_realesrgan_models from modules import devices from typing import List @@ -303,7 +303,7 @@ class Api: return upscalers def get_sd_models(self): - return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()] + return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] -- cgit v1.2.3 From b2151b934fe0a3613570c6abd7615d3788fd1c8f Mon Sep 17 00:00:00 2001 From: MMaker Date: Wed, 4 Jan 2023 05:36:18 -0500 Subject: Rename bicubic antialiased option Comma was causing the the value in PNG info to be quoted, which causes the upscaler dropdown option to be blank when sending to UI --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index a10f69a9..c1b20081 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -567,7 +567,7 @@ latent_upscale_modes = { "Latent": {"mode": "bilinear", "antialias": False}, "Latent (antialiased)": {"mode": "bilinear", "antialias": True}, "Latent (bicubic)": {"mode": "bicubic", "antialias": False}, - "Latent (bicubic, antialiased)": {"mode": "bicubic", "antialias": True}, + "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True}, "Latent (nearest)": {"mode": "nearest", "antialias": False}, } -- cgit v1.2.3 From 3bd737767b071878ea980e94b8705f603bcf545e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 14:20:32 +0300 Subject: disable broken API logging --- modules/api/api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index a6c1d6ed..6267afdc 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -100,7 +100,6 @@ class Api: self.router = APIRouter() self.app = app - init_api_middleware(self.app) self.queue_lock = queue_lock self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], 
response_model=ImageToImageResponse) -- cgit v1.2.3 From 0cd6399b8b1699b8b7acad6f0ad2988111fe618e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 14:29:13 +0300 Subject: fix broken inpainting model --- modules/sd_models.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 6dca4ddf..a568823d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -305,9 +305,6 @@ def load_model(checkpoint_info=None): sd_config.model.params.unet_config.params.in_channels = 9 sd_config.model.params.finetune_keys = None - # Create a "fake" config with a different name so that we know to unload it when switching models. - checkpoint_info = checkpoint_info._replace(config=checkpoint_config.replace(".yaml", "-inpainting.yaml")) - if not hasattr(sd_config.model.params, "use_ema"): sd_config.model.params.use_ema = False -- cgit v1.2.3 From 11b8160a086c434d5baf4971edda46e6d2126800 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Wed, 4 Jan 2023 06:36:57 -0500 Subject: fix typo --- modules/api/api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/api/api.py b/modules/api/api.py index 6267afdc..48a70a44 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -101,6 +101,7 @@ class Api: self.router = APIRouter() self.app = app self.queue_lock = queue_lock + api_middleware(self.app) self.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse) self.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse) self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) -- cgit v1.2.3 From 642142556d8ecdea9beb86d7618b628b1803ab98 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 15:09:53 +0300 Subject: use commandline-supplied cuda device name instead of cuda:0 for safetensors PR that doesn't fix anything --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index ee918f24..76a89e88 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -173,7 +173,7 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None if extension.lower() == ".safetensors": device = map_location or shared.weight_load_location if device is None: - device = "cuda:0" if torch.cuda.is_available() else "cpu" + device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu" pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) else: pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) -- cgit v1.2.3 From 21ee77db314ede7ccbb18787962347c09a4df0c7 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Wed, 4 Jan 2023 08:04:38 -0500 Subject: add cross-attention info --- modules/sd_hijack.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index edcbaf52..fa2cd4bb 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -35,26 +35,35 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th + + optimization_method = None if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= 
torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward + optimization_method = 'xformers' elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 + optimization_method = 'V1' elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): if not invokeAI_mps_available and shared.device.type == 'mps': print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.") print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 + optimization_method = 'V1' else: print("Applying cross attention optimization (InvokeAI).") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI + optimization_method = 'InvokeAI' elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): print("Applying cross attention optimization (Doggettx).") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward + optimization_method = 'Doggettx' + + return optimization_method def undo_optimizations(): @@ -75,6 +84,7 @@ class StableDiffusionModelHijack: layers = None circular_enabled = False clip = None + optimization_method = None embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) @@ -94,7 +104,7 @@ class StableDiffusionModelHijack: m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self) m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) - apply_optimizations() + self.optimization_method = apply_optimizations() self.clip = m.cond_stage_model -- cgit v1.2.3 From 1cfd8aec4ae5a6ca1afd67b44cb4ef6dd14d8c34 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 16:05:42 +0300 Subject: make it possible to work with opts.show_progress_every_n_steps = -1 with medvram --- modules/shared.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index 4fcc6edd..54a6ba23 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -214,12 +214,13 @@ class State: """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" def set_current_image(self): + if not parallel_processing_allowed: + return + if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0: self.do_set_current_image() def do_set_current_image(self): - if not parallel_processing_allowed: - return if self.current_latent is None: return @@ -231,6 +232,7 @@ class State: self.current_image_sampling_step = self.sampling_step + state = State() artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv')) -- cgit v1.2.3 From 
79c682ad4f2d982b26fa1a15044582d1005134f9 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Wed, 4 Jan 2023 08:20:42 -0500 Subject: fix jpeg --- modules/extras.py | 2 -- modules/images.py | 2 ++ requirements_versions.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index d665440a..7407bfe3 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -19,8 +19,6 @@ from modules.shared import opts import modules.gfpgan_model from modules.ui import plaintext_to_html import modules.codeformer_model -import piexif -import piexif.helper import gradio as gr import safetensors.torch diff --git a/modules/images.py b/modules/images.py index c3a5fc8b..a73be3fa 100644 --- a/modules/images.py +++ b/modules/images.py @@ -22,6 +22,8 @@ from modules.shared import opts, cmd_opts LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) +Image.init() # initialize once all known file format handlers + def image_grid(imgs, batch_size=1, rows=None): if rows is None: diff --git a/requirements_versions.txt b/requirements_versions.txt index 975102d9..7ae118cb 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -5,7 +5,7 @@ basicsr==1.4.2 gfpgan==1.3.8 gradio==3.15.0 numpy==1.23.3 -Pillow==9.2.0 +Pillow==9.3.0 realesrgan==0.3.0 torch omegaconf==2.2.3 -- cgit v1.2.3 From 4d66bf2c0d27702cc83b9cc57ebb1f359d18d938 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 17:24:46 +0300 Subject: add infotext to "-before-highres-fix" images --- modules/processing.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index fd7c7015..c03e77e7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -136,6 +136,7 @@ class StableDiffusionProcessing(): self.all_negative_prompts = None self.all_seeds = None self.all_subseeds = None + self.iteration = 0 def txt2img_image_conditioning(self, x, width=None, height=None): if self.sampler.conditioning_key not in {'hybrid', 'concat'}: @@ -544,6 +545,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: state.job_count = p.n_iter for n in range(p.n_iter): + p.iteration = n + if state.skipped: state.skipped = False @@ -707,7 +710,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if not isinstance(image, Image.Image): image = sd_samplers.sample_to_image(image, index, approximation=0) - images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix") + info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index) + images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, suffix="-before-highres-fix") if latent_scale_mode is not None: for i in range(samples.shape[0]): -- cgit v1.2.3 From 184e670126f5fc50ba56fa0fedcf0cf60e45ed7e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 17:45:01 +0300 Subject: fix the merge --- modules/textual_inversion/textual_inversion.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5421a758..8731ea5d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -251,6 +251,7 @@ def 
validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" + def create_dummy_mask(x, width=None, height=None): if shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}: @@ -380,17 +381,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ break with devices.autocast(): - # c = stack_conds(batch.cond).to(devices.device) - # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory) - # print(mask) - # c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory) - - - if img_c is None: - img_c = create_dummy_mask(c, training_width, training_height) - x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) c = shared.sd_model.cond_stage_model(batch.cond_text) + + if img_c is None: + img_c = create_dummy_mask(c, training_width, training_height) + cond = {"c_concat": [img_c], "c_crossattn": [c]} loss = shared.sd_model(x, cond)[0] / gradient_step del x -- cgit v1.2.3 From 590c5ae016ae494f4873ca20079b30684ea3060c Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Wed, 4 Jan 2023 09:48:54 -0500 Subject: update pillow --- modules/images.py | 2 -- requirements_versions.txt | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/modules/images.py b/modules/images.py index a73be3fa..c3a5fc8b 100644 --- a/modules/images.py +++ b/modules/images.py @@ -22,8 +22,6 @@ from modules.shared import opts, cmd_opts LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) -Image.init() # initialize once all known file format handlers - def image_grid(imgs, batch_size=1, rows=None): if rows is None: diff --git a/requirements_versions.txt b/requirements_versions.txt index 7ae118cb..d2899292 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -5,7 +5,7 @@ basicsr==1.4.2 gfpgan==1.3.8 gradio==3.15.0 numpy==1.23.3 -Pillow==9.3.0 +Pillow==9.4.0 realesrgan==0.3.0 torch omegaconf==2.2.3 -- cgit v1.2.3 From 525cea924562afd676f55470095268a0f6fca59e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 17:58:07 +0300 Subject: use shared function from processing for creating dummy mask when training inpainting model --- modules/processing.py | 39 +++++++++++++------------- modules/textual_inversion/textual_inversion.py | 33 ++++++---------------- 2 files changed, 29 insertions(+), 43 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index c03e77e7..c7264aff 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -76,6 +76,24 @@ def apply_overlay(image, paste_loc, index, overlays): return image +def txt2img_image_conditioning(sd_model, x, width, height): + if sd_model.model.conditioning_key not in {'hybrid', 'concat'}: + # Dummy zero conditioning if we're not using inpainting model. + # Still takes up a bit of memory, but no encoder call. + # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. + return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) + + # The "masked-image" in this case will just be all zeros since the entire image is masked. + image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) + image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning)) + + # Add the fake full 1s mask to the first dimension. 
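# Note: the pad() call just below is what adds that mask channel; a minimal,
# self-contained sketch of the operation (shapes are illustrative assumptions,
# not values taken from this patch):
#
#     import torch
#     latent = torch.zeros(2, 4, 64, 64)  # stand-in for the encoded masked image
#     cond = torch.nn.functional.pad(latent, (0, 0, 0, 0, 1, 0), value=1.0)
#     assert cond.shape == (2, 5, 64, 64)  # one all-ones mask channel prepended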
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) + + return image_conditioning + + class StableDiffusionProcessing(): """ The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing @@ -139,26 +157,9 @@ class StableDiffusionProcessing(): self.iteration = 0 def txt2img_image_conditioning(self, x, width=None, height=None): - if self.sampler.conditioning_key not in {'hybrid', 'concat'}: - # Dummy zero conditioning if we're not using inpainting model. - # Still takes up a bit of memory, but no encoder call. - # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. - return x.new_zeros(x.shape[0], 5, 1, 1) + self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'} - self.is_using_inpainting_conditioning = True - - height = height or self.height - width = width or self.width - - # The "masked-image" in this case will just be all zeros since the entire image is masked. - image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) - image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) - - # Add the fake full 1s mask to the first dimension. - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) - image_conditioning = image_conditioning.to(x.dtype) - - return image_conditioning + return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height) def depth2img_image_conditioning(self, source_image): # Use the AddMiDaS helper to Format our source image to suit the MiDaS model diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 8731ea5d..2250e41b 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -252,26 +252,6 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat assert log_directory, "Log directory is empty" -def create_dummy_mask(x, width=None, height=None): - if shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}: - - # The "masked-image" in this case will just be all zeros since the entire image is masked. - image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) - image_conditioning = shared.sd_model.get_first_stage_encoding(shared.sd_model.encode_first_stage(image_conditioning)) - - # Add the fake full 1s mask to the first dimension. - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) - image_conditioning = image_conditioning.to(x.dtype) - - else: - # Dummy zero conditioning if we're not using inpainting model. - # Still takes up a bit of memory, but no encoder call. - # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. 
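# Note: with create_dummy_mask removed, the training loop later in this patch
# picks the conditioning by model type; a runnable sketch of that dispatch
# (the helper name is illustrative, not part of the patch):
#
#     def pick_conditioning(conditioning_key, img_c, c):
#         if conditioning_key in {'hybrid', 'concat'}:  # inpainting-style model
#             return {"c_concat": [img_c], "c_crossattn": [c]}
#         return c  # plain text conditioning for everything else
#
#     assert pick_conditioning('crossattn', None, 'text-cond') == 'text-cond'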
- image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) - - return image_conditioning - - def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 @@ -346,7 +326,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ else: print("No saved optimizer exists in checkpoint") - scaler = torch.cuda.amp.GradScaler() batch_size = ds.batch_size @@ -362,7 +341,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ forced_filename = "" embedding_yet_to_be_embedded = False + is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'} img_c = None + pbar = tqdm.tqdm(total=steps - initial_step) try: for i in range((steps-initial_step) * gradient_step): @@ -384,10 +365,14 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ x = batch.latent_sample.to(devices.device, non_blocking=pin_memory) c = shared.sd_model.cond_stage_model(batch.cond_text) - if img_c is None: - img_c = create_dummy_mask(c, training_width, training_height) + if is_training_inpainting_model: + if img_c is None: + img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height) + + cond = {"c_concat": [img_c], "c_crossattn": [c]} + else: + cond = c - cond = {"c_concat": [img_c], "c_crossattn": [c]} loss = shared.sd_model(x, cond)[0] / gradient_step del x -- cgit v1.2.3 From a8eb9e3bf814f72293e474c11e9ff0098859a942 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 18:20:38 +0300 Subject: Revert "Merge pull request #3791 from shirayu/fix/filename" This reverts commit eed58279e7cb0e873ebd88a29609f9bab0f1f3af, reversing changes made to 4ae960b01c6711c66985479f14809dc7fa549fc2. --- modules/images.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/modules/images.py b/modules/images.py index 2967fa9a..c3a5fc8b 100644 --- a/modules/images.py +++ b/modules/images.py @@ -447,14 +447,6 @@ def get_next_sequence_number(path, basename): return result + 1 -def truncate_fullpath(full_path, encoding='utf-8'): - dir_name, full_name = os.path.split(full_path) - file_name, file_ext = os.path.splitext(full_name) - max_length = os.statvfs(dir_name).f_namemax - file_name_truncated = file_name.encode(encoding)[:max_length - len(file_ext)].decode(encoding, 'ignore') - return os.path.join(dir_name , file_name_truncated + file_ext) - - def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None): """Save an image. 
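Note on the revert above: the removed truncate_fullpath relied on os.statvfs, which is only available on Unix -- a plausible reason for backing it out, though the commit message does not say. A rough cross-platform sketch (a hypothetical helper, not from this repo) would cap the basename at a fixed byte budget instead of querying the filesystem:

    import os

    def truncate_filename(full_path, max_bytes=255, encoding='utf-8'):
        # Clamp the basename to max_bytes while keeping the extension intact.
        dir_name, full_name = os.path.split(full_path)
        name, ext = os.path.splitext(full_name)
        budget = max_bytes - len(ext.encode(encoding))
        name = name.encode(encoding)[:budget].decode(encoding, 'ignore')
        return os.path.join(dir_name, name + ext)

    assert truncate_filename('/tmp/' + 'a' * 300 + '.png').endswith('.png')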
@@ -495,7 +487,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i if save_to_dirs: dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /') - path = truncate_fullpath(os.path.join(path, dirname)) + path = os.path.join(path, dirname) os.makedirs(path, exist_ok=True) @@ -519,13 +511,13 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i fullfn = None for i in range(500): fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}" - fullfn = truncate_fullpath(os.path.join(path, f"{fn}{file_decoration}.{extension}")) + fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}") if not os.path.exists(fullfn): break else: - fullfn = truncate_fullpath(os.path.join(path, f"{file_decoration}.{extension}")) + fullfn = os.path.join(path, f"{file_decoration}.{extension}") else: - fullfn = truncate_fullpath(os.path.join(path, f"{forced_filename}.{extension}")) + fullfn = os.path.join(path, f"{forced_filename}.{extension}") pnginfo = existing_info or {} if info is not None: -- cgit v1.2.3 From 3dae545a03f5102ba5d9c3f27bb6241824c5a916 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 18:42:51 +0300 Subject: rename weirdly named variables from #3176 --- modules/ui.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index e4859020..184af7ad 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -162,16 +162,14 @@ def save_files(js_data, images, do_make_zip, index): return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") - - -def calc_time_left(progress, threshold, label, force_display, showTime): +def calc_time_left(progress, threshold, label, force_display, show_eta): if progress == 0: return "" else: time_since_start = time.time() - shared.state.time_start eta = (time_since_start/progress) eta_relative = eta-time_since_start - if (eta_relative > threshold and showTime) or force_display: + if (eta_relative > threshold and show_eta) or force_display: if eta_relative > 3600: return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) elif eta_relative > 60: @@ -194,9 +192,9 @@ def check_progress_call(id_part): progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps # Show progress percentage and time left at the same moment, and base it also on steps done - showPBText = progress >= 0.01 or shared.state.sampling_step >= 10 + show_eta = progress >= 0.01 or shared.state.sampling_step >= 10 - time_left = calc_time_left( progress, 1, " ETA: ", shared.state.time_left_force_display, showPBText ) + time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta) if time_left != "": shared.state.time_left_force_display = True @@ -204,7 +202,7 @@ def check_progress_call(id_part): progressbar = "" if opts.show_progressbar: - progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if showPBText else ""}
""" + progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}
""" image = gr_show(False) preview_visibility = gr_show(False) -- cgit v1.2.3 From 097a90b88bb92878cf435c513b4757b5b82ae299 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 19:19:11 +0300 Subject: add XY plot parameters to grid image and do not add them to individual images --- modules/processing.py | 2 +- scripts/xy_grid.py | 38 ++++++++++++++++++++++++-------------- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index c7264aff..47712159 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -422,7 +422,7 @@ def fix_seed(p): p.subseed = get_fixed_seed(p.subseed) -def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0): +def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0): index = position_in_batch + iteration * p.batch_size clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 59907f0b..78ff12c5 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,7 +10,7 @@ import numpy as np import modules.scripts as scripts import gradio as gr -from modules import images, paths, sd_samplers +from modules import images, paths, sd_samplers, processing from modules.hypernetworks import hypernetwork from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img from modules.shared import opts, cmd_opts, state @@ -285,6 +285,7 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") + class Script(scripts.Script): def title(self): return "X/Y plot" @@ -381,7 +382,7 @@ class Script(scripts.Script): ys = process_axis(y_opt, y_values) def fix_axis_seeds(axis_opt, axis_list): - if axis_opt.label in ['Seed','Var. seed']: + if axis_opt.label in ['Seed', 'Var. seed']: return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] else: return axis_list @@ -403,24 +404,33 @@ class Script(scripts.Script): print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})") shared.total_tqdm.updateTotal(total_steps * p.n_iter) + grid_infotext = [None] + def cell(x, y): pc = copy(p) x_opt.apply(pc, x, xs) y_opt.apply(pc, y, ys) - return process_images(pc) + res = process_images(pc) + + if grid_infotext[0] is None: + pc.extra_generation_params = copy(pc.extra_generation_params) + + if x_opt.label != 'Nothing': + pc.extra_generation_params["X Type"] = x_opt.label + pc.extra_generation_params["X Values"] = x_values + if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds: + pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs]) + + if y_opt.label != 'Nothing': + pc.extra_generation_params["Y Type"] = y_opt.label + pc.extra_generation_params["Y Values"] = y_values + if y_opt.label in ["Seed", "Var. 
seed"] and not no_fixed_seeds: + pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys]) - if not x_opt.label == 'Nothing': - p.extra_generation_params["XY Plot X Type"] = x_opt.label - p.extra_generation_params["XY Plot X Values"] = '{' + x_values + '}' - if x_opt.label in ["Seed","Var. seed"] and not no_fixed_seeds: - p.extra_generation_params["XY Plot Fixed X Values"] = '{' + ", ".join([str(x) for x in xs])+ '}' + grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds) - if not y_opt.label == 'Nothing': - p.extra_generation_params["XY Plot Y Type"] = y_opt.label - p.extra_generation_params["XY Plot Y Values"] = '{' + y_values + '}' - if y_opt.label in ["Seed","Var. seed"] and not no_fixed_seeds: - p.extra_generation_params["XY Plot Fixed Y Values"] = '{' + ", ".join([str(y) for y in ys])+ '}' + return res with SharedSettingsStackHelper(): processed = draw_xy_grid( @@ -435,6 +445,6 @@ class Script(scripts.Script): ) if opts.grid_save: - images.save_image(processed.images[0], p.outpath_grids, "xy_grid", extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) + images.save_image(processed.images[0], p.outpath_grids, "xy_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) return processed -- cgit v1.2.3 From 24d4a0841d3cc0e5908b098f65a9caa3fa889af8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 20:10:40 +0300 Subject: train tab visual updates allow setting train tab values from ui-config.json --- modules/ui.py | 35 +++++++++++++++++++++-------------- style.css | 2 +- 2 files changed, 22 insertions(+), 15 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 72e7b7d2..44f4f3a4 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1281,42 +1281,48 @@ def create_ui(): with gr.Tab(label="Train"): gr.HTML(value="
Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]
") - with gr.Row(): + with FormRow(): train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") - with gr.Row(): + train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") - with gr.Row(): + + with FormRow(): embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") - with gr.Row(): + with FormRow(): clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) - batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") - gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") + with FormRow(): + batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") + gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") + dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file") training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") - create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") - save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") + + with FormRow(): + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") + save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") preview_from_txt2img = gr.Checkbox(label='Read parameters 
(prompt, etc...) from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") - with gr.Row(): - shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") - tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") - with gr.Row(): - latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") + + shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") + tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") + + latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") with gr.Row(): + train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") - train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") params = script_callbacks.UiTrainTabParams(txt2img_preview_params) @@ -1803,6 +1809,7 @@ def create_ui(): visit(img2img_interface, loadsave, "img2img") visit(extras_interface, loadsave, "extras") visit(modelmerger_interface, loadsave, "modelmerger") + visit(train_interface, loadsave, "train") if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): with open(ui_config_file, "w", encoding="utf8") as file: diff --git a/style.css b/style.css index 2116ec3c..09ee540b 100644 --- a/style.css +++ b/style.css @@ -611,7 +611,7 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h padding-top: 0.9em; } -#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form{ +#img2img_settings div.gr-form .gr-form, #txt2img_settings div.gr-form .gr-form, #train_tabs div.gr-form .gr-form{ border: none; padding-bottom: 0.5em; } -- cgit v1.2.3 From 81490780949fffed77493b4bd741e96ec737fe27 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 22:04:40 +0300 Subject: added the option to specify target resolution with possibility of truncating for hires fix; also sampling steps --- javascript/hints.js | 11 ++++--- modules/generation_parameters_copypaste.py | 9 ++++-- modules/processing.py | 51 +++++++++++++++++++++++++++--- modules/txt2img.py | 5 ++- modules/ui.py | 24 ++++++++++---- 5 files changed, 81 insertions(+), 19 deletions(-) diff --git a/javascript/hints.js b/javascript/hints.js index 63e17e05..dda66e09 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -81,9 +81,6 @@ titles = { "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).", - "Highres. 
fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition", - "Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.", - "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.", "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.", @@ -100,7 +97,13 @@ titles = { "Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.", "Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.", - "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality." + "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality.", + + "Hires. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition", + "Hires steps": "Number of sampling steps for upscaled picture. If 0, uses same as for original.", + "Upscale by": "Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.", + "Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.", + "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders." 
} diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 4baf4d9a..12a9de3d 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -212,11 +212,10 @@ def restore_old_hires_fix_params(res): firstpass_width = math.ceil(scale * width / 64) * 64 firstpass_height = math.ceil(scale * height / 64) * 64 - hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height - res['Size-1'] = firstpass_width res['Size-2'] = firstpass_height - res['Hires upscale'] = hr_scale + res['Hires resize-1'] = width + res['Hires resize-2'] = height def parse_generation_parameters(x: str): @@ -276,6 +275,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hypernet_hash = res.get("Hypernet hash", None) res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash) + if "Hires resize-1" not in res: + res["Hires resize-1"] = 0 + res["Hires resize-2"] = 0 + restore_old_hires_fix_params(res) return res diff --git a/modules/processing.py b/modules/processing.py index 47712159..9cad05f2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -662,12 +662,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None - def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs): + def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr self.denoising_strength = denoising_strength self.hr_scale = hr_scale self.hr_upscaler = hr_upscaler + self.hr_second_pass_steps = hr_second_pass_steps + self.hr_resize_x = hr_resize_x + self.hr_resize_y = hr_resize_y + self.hr_upscale_to_x = hr_resize_x + self.hr_upscale_to_y = hr_resize_y if firstphase_width != 0 or firstphase_height != 0: print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr) @@ -675,6 +680,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.width = firstphase_width self.height = firstphase_height + self.truncate_x = 0 + self.truncate_y = 0 + def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: if state.job_count == -1: @@ -682,7 +690,38 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: state.job_count = state.job_count * 2 - self.extra_generation_params["Hires upscale"] = self.hr_scale + if self.hr_resize_x == 0 and self.hr_resize_y == 0: + self.extra_generation_params["Hires upscale"] = self.hr_scale + self.hr_upscale_to_x = int(self.width * self.hr_scale) + self.hr_upscale_to_y = int(self.height * self.hr_scale) + else: + self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}" + + if self.hr_resize_y == 0: + self.hr_upscale_to_x = self.hr_resize_x + self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width + elif self.hr_resize_x == 0: + self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height + self.hr_upscale_to_y = self.hr_resize_y + else: + target_w = self.hr_resize_x + target_h = self.hr_resize_y + src_ratio = self.width / self.height + 
dst_ratio = self.hr_resize_x / self.hr_resize_y + + if src_ratio < dst_ratio: + self.hr_upscale_to_x = self.hr_resize_x + self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width + else: + self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height + self.hr_upscale_to_y = self.hr_resize_y + + self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f + self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f + + if self.hr_second_pass_steps: + self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps + if self.hr_upscaler is not None: self.extra_generation_params["Hires upscaler"] = self.hr_upscaler @@ -699,8 +738,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if not self.enable_hr: return samples - target_width = int(self.width * self.hr_scale) - target_height = int(self.height * self.hr_scale) + target_width = self.hr_upscale_to_x + target_height = self.hr_upscale_to_y def save_intermediate(image, index): """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images""" @@ -755,13 +794,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) + samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] + noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self) # GC now before running the next img2img to prevent running out of memory x = None devices.torch_gc() - samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning) + samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) return samples diff --git a/modules/txt2img.py b/modules/txt2img.py index e189a899..38b5f591 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -8,7 +8,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -35,6 +35,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: denoising_strength=denoising_strength if enable_hr else None, hr_scale=hr_scale, hr_upscaler=hr_upscaler, + 
hr_second_pass_steps=hr_second_pass_steps, + hr_resize_x=hr_resize_x, + hr_resize_y=hr_resize_y, ) p.scripts = modules.scripts.scripts_txt2img diff --git a/modules/ui.py b/modules/ui.py index 44f4f3a4..04091e67 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -637,10 +637,10 @@ def create_sampler_and_steps_selection(choices, tabname): with FormRow(elem_id=f"sampler_selection_{tabname}"): sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") sampler_index.save_to_config = True - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20) + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) else: with FormGroup(elem_id=f"sampler_selection_{tabname}"): - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20) + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") return steps, sampler_index @@ -709,10 +709,16 @@ def create_ui(): enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") elif category == "hires_fix": - with FormRow(visible=False, elem_id="txt2img_hires_fix") as hr_options: - hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) - hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") + with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: + with FormRow(elem_id="txt2img_hires_fix_row1"): + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") + + with FormRow(elem_id="txt2img_hires_fix_row2"): + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") + hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") + hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") elif category == "batch": if not opts.dimensions_and_batch_together: @@ -753,6 +759,9 @@ def create_ui(): denoising_strength, hr_scale, hr_upscaler, + hr_second_pass_steps, + hr_resize_x, + hr_resize_y, ] + custom_inputs, outputs=[ @@ -804,6 +813,9 @@ def create_ui(): (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (hr_scale, "Hires upscale"), (hr_upscaler, "Hires upscaler"), + (hr_second_pass_steps, "Hires steps"), + (hr_resize_x, "Hires resize-1"), + (hr_resize_y, "Hires resize-2"), 
*modules.scripts.scripts_txt2img.infotext_fields ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) -- cgit v1.2.3 From 1288a3bb7d21064e5bd0af7158a3840886027c51 Mon Sep 17 00:00:00 2001 From: Suffocate <70031311+lolsuffocate@users.noreply.github.com> Date: Wed, 4 Jan 2023 20:36:30 +0000 Subject: Use the read_info_from_image function directly --- modules/api/api.py | 16 ++++++++++++---- modules/api/models.py | 5 +++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 48a70a44..2103709b 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -11,10 +11,10 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack +from modules import sd_samplers, deepbooru, sd_hijack, images from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.extras import run_extras, run_pnginfo +from modules.extras import run_extras from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork @@ -233,9 +233,17 @@ class Api: if(not req.image.strip()): return PNGInfoResponse(info="") - result = run_pnginfo(decode_base64_to_image(req.image.strip())) + image = decode_base64_to_image(req.image.strip()) + if image is None: + return PNGInfoResponse(info="") + + geninfo, items = images.read_info_from_image(image) + if geninfo is None: + geninfo = "" + + items = {**{'parameters': geninfo}, **items} - return PNGInfoResponse(info=result[1]) + return PNGInfoResponse(info=geninfo, items=items) def progressapi(self, req: ProgressRequest = Depends()): # copy from check_progress_call of ui.py diff --git a/modules/api/models.py b/modules/api/models.py index 4a632c68..d8198a27 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -157,7 +157,8 @@ class PNGInfoRequest(BaseModel): image: str = Field(title="Image", description="The base64 encoded PNG image") class PNGInfoResponse(BaseModel): - info: str = Field(title="Image info", description="A string with all the info the image had") + info: str = Field(title="Image info", description="A string with the parameters used to generate the image") + items: dict = Field(title="Items", description="An object containing all the info the image had") class ProgressRequest(BaseModel): skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization") @@ -258,4 +259,4 @@ class EmbeddingItem(BaseModel): class EmbeddingsResponse(BaseModel): loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model") - skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)") \ No newline at end of file + skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)") -- cgit v1.2.3 From bc43293c640aef65df3136de9e5bd8b7e79eb3e0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 23:56:43 +0300 Subject: fix incorrect display/calculation for number of steps for hires fix in progress 
bars --- modules/processing.py | 9 ++++++--- modules/sd_samplers.py | 5 +++-- modules/shared.py | 4 +++- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 9cad05f2..f28e7212 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -685,10 +685,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: - if state.job_count == -1: - state.job_count = self.n_iter * 2 - else: + if not state.processing_has_refined_job_count: + if state.job_count == -1: + state.job_count = self.n_iter + + shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count) state.job_count = state.job_count * 2 + state.processing_has_refined_job_count = True if self.hr_resize_x == 0 and self.hr_resize_y == 0: self.extra_generation_params["Hires upscale"] = self.hr_scale diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index e904d860..3851a77f 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -97,8 +97,9 @@ sampler_extra_params = { def setup_img2img_steps(p, steps=None): if opts.img2img_fix_steps or steps is not None: - steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0 - t_enc = p.steps - 1 + requested_steps = (steps or p.steps) + steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0 + t_enc = requested_steps - 1 else: steps = p.steps t_enc = int(min(p.denoising_strength, 0.999) * steps) diff --git a/modules/shared.py b/modules/shared.py index 54a6ba23..04c545ee 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -153,6 +153,7 @@ class State: job = "" job_no = 0 job_count = 0 + processing_has_refined_job_count = False job_timestamp = '0' sampling_step = 0 sampling_steps = 0 @@ -194,6 +195,7 @@ class State: def begin(self): self.sampling_step = 0 self.job_count = -1 + self.processing_has_refined_job_count = False self.job_no = 0 self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") self.current_latent = None @@ -608,7 +610,7 @@ class TotalTQDM: return if self._tqdm is None: self.reset() - self._tqdm.total=new_total + self._tqdm.total = new_total def clear(self): if self._tqdm is not None: -- cgit v1.2.3 From 5851bc839b6f639cda59e84eb1ee8c706986633d Mon Sep 17 00:00:00 2001 From: me <25877290+Kryptortio@users.noreply.github.com> Date: Wed, 4 Jan 2023 22:03:32 +0100 Subject: Add element ids for script components and a few more in ui.py --- modules/ui.py | 16 ++++++++-------- scripts/custom_code.py | 4 +++- scripts/img2imgalt.py | 22 ++++++++++++---------- scripts/loopback.py | 6 ++++-- scripts/outpainting_mk_2.py | 12 +++++++----- scripts/poor_mans_outpainting.py | 10 ++++++---- scripts/prompt_matrix.py | 6 ++++-- scripts/prompts_from_file.py | 10 ++++++---- scripts/sd_upscale.py | 8 +++++--- scripts/xy_grid.py | 15 ++++++++------- 10 files changed, 63 insertions(+), 46 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 04091e67..bb64fe20 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -560,7 +560,7 @@ Requested path was: {f} generation_info = None with gr.Column(): with gr.Row(elem_id=f"image_buttons_{tabname}"): - open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder') + open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else 
f'open_folder_{tabname}') if tabname != "extras": save = gr.Button('Save', elem_id=f'save_{tabname}') @@ -576,13 +576,13 @@ Requested path was: {f} if tabname != "extras": with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') with gr.Group(): - html_info = gr.HTML() - html_log = gr.HTML() + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') - generation_info = gr.Textbox(visible=False) + generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') if tabname == 'txt2img' or tabname == 'img2img': generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") generation_info_button.click( @@ -624,9 +624,9 @@ Requested path was: {f} ) else: - html_info_x = gr.HTML() - html_info = gr.HTML() - html_log = gr.HTML() + html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log diff --git a/scripts/custom_code.py b/scripts/custom_code.py index 22e7b77a..841fed97 100644 --- a/scripts/custom_code.py +++ b/scripts/custom_code.py @@ -14,7 +14,9 @@ class Script(scripts.Script): return cmd_opts.allow_code def ui(self, is_img2img): - code = gr.Textbox(label="Python code", lines=1) + elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_custom_code_' + + code = gr.Textbox(label="Python code", lines=1, elem_id=elem_prefix + "code") return [code] diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py index 1229f61b..cddd46e7 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -126,24 +126,26 @@ class Script(scripts.Script): return is_img2img def ui(self, is_img2img): + elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_i2i_alternative_test_' + info = gr.Markdown(''' * `CFG Scale` should be 2 or lower. 
''') - override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler?(this method is built for it)", value=True) + override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler?(this method is built for it)", value=True, elem_id=elem_prefix + "override_sampler") - override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`?(and `negative prompt`)", value=True) - original_prompt = gr.Textbox(label="Original prompt", lines=1) - original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1) + override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`?(and `negative prompt`)", value=True, elem_id=elem_prefix + "override_prompt") + original_prompt = gr.Textbox(label="Original prompt", lines=1, elem_id=elem_prefix + "original_prompt") + original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1, elem_id=elem_prefix + "original_negative_prompt") - override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True) - st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50) + override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True, elem_id=elem_prefix + "override_steps") + st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50, elem_id=elem_prefix + "st") - override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True) + override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True, elem_id=elem_prefix + "override_strength") - cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0) - randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0) - sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False) + cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0, elem_id=elem_prefix + "cfg") + randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0, elem_id=elem_prefix + "randomness") + sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=elem_prefix + "sigma_adjustment") return [ info, diff --git a/scripts/loopback.py b/scripts/loopback.py index d8c68af8..5c1265a0 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -17,8 +17,10 @@ class Script(scripts.Script): return is_img2img def ui(self, is_img2img): - loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4) - denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1) + elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_loopback_' + + loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=elem_prefix + "loops") + denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, elem_id=elem_prefix + "denoising_strength_change_factor") return [loops, denoising_strength_change_factor] diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index cf71cb92..760cce64 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -129,13 +129,15 @@ class Script(scripts.Script): if not is_img2img: return None + elem_prefix = ('i2i' if is_img2img else 't2i') + 
'_script_outpainting_mk_2_' + info = gr.HTML("
Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8
") - pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128) - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8) - direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down']) - noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0) - color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05) + pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=elem_prefix + "pixels") + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=elem_prefix + "mask_blur") + direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=elem_prefix + "direction") + noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=elem_prefix + "noise_q") + color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=elem_prefix + "color_variation") return [info, pixels, mask_blur, direction, noise_q, color_variation] diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index ea45beb0..6bcdcc02 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -21,10 +21,12 @@ class Script(scripts.Script): if not is_img2img: return None - pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128) - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4) - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index") - direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down']) + elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_poor_mans_outpainting_' + + pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=elem_prefix + "pixels") + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=elem_prefix + "mask_blur") + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=elem_prefix + "inpainting_fill") + direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=elem_prefix + "direction") return [pixels, mask_blur, inpainting_fill, direction] diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index 4c79eaef..59172315 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -45,8 +45,10 @@ class Script(scripts.Script): return "Prompt matrix" def ui(self, is_img2img): - put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False) - different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False) + elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_prompt_matrix_' + + put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=elem_prefix + "put_at_start") + different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=elem_prefix + 
"different_seeds") return [put_at_start, different_seeds] diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index e8386ed2..fc8ddd8a 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -112,11 +112,13 @@ class Script(scripts.Script): return "Prompts from file or textbox" def ui(self, is_img2img): - checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False) - checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False) + elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_prompt_from_file_' + + checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=elem_prefix + "checkbox_iterate") + checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=elem_prefix + "checkbox_iterate_batch") - prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1) - file = gr.File(label="Upload prompt inputs", type='bytes') + prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=elem_prefix + "prompt_txt") + file = gr.File(label="Upload prompt inputs", type='bytes', elem_id=elem_prefix + "file") file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt]) diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py index 9739545c..9f483a67 100644 --- a/scripts/sd_upscale.py +++ b/scripts/sd_upscale.py @@ -17,10 +17,12 @@ class Script(scripts.Script): return is_img2img def ui(self, is_img2img): + elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_sd_upscale_' + info = gr.HTML("
<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>
") - overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64) - scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0) - upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=elem_prefix + "overlap") + scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=elem_prefix + "scale_factor") + upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", elem_id=elem_prefix + "upscaler_index") return [info, overlap, upscaler_index, scale_factor] diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 78ff12c5..90226ccd 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -292,18 +292,19 @@ class Script(scripts.Script): def ui(self, is_img2img): current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img] + elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_xy_grid_' with gr.Row(): - x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id="x_type") - x_values = gr.Textbox(label="X values", lines=1) + x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=elem_prefix + "x_type") + x_values = gr.Textbox(label="X values", lines=1, elem_id=elem_prefix + "x_values") with gr.Row(): - y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id="y_type") - y_values = gr.Textbox(label="Y values", lines=1) + y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=elem_prefix + "y_type") + y_values = gr.Textbox(label="Y values", lines=1, elem_id=elem_prefix + "y_values") - draw_legend = gr.Checkbox(label='Draw legend', value=True) - include_lone_images = gr.Checkbox(label='Include Separate Images', value=False) - no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False) + draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=elem_prefix + "draw_legend") + include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=elem_prefix + "include_lone_images") + no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=elem_prefix + "no_fixed_seeds") return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] -- cgit v1.2.3 From b663ee2cff6831354e1b5326800c8d1bf300cafe Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 5 Jan 2023 00:36:10 +0300 Subject: fix fullscreen view showing wrong image on firefox --- javascript/imageviewer.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 67916536..97f56c07 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -148,7 +148,7 @@ function showGalleryImage() { if(e && e.parentElement.tagName == 'DIV'){ e.style.cursor='pointer' e.style.userSelect='none' - e.addEventListener('click', function (evt) { + e.addEventListener('mousedown', function (evt) { 
if(!opts.js_modal_lightbox) return; modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed) showModal(evt) -- cgit v1.2.3 From 99b67cff0b48c4a1ad6e14d9cc591b11db6e293c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 5 Jan 2023 01:25:52 +0300 Subject: make hires fix not do anything if the user chooses the second pass resolution to be the same as first pass resolution --- modules/processing.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index f28e7212..7e853287 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -683,16 +683,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = 0 self.truncate_y = 0 + def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: - if not state.processing_has_refined_job_count: - if state.job_count == -1: - state.job_count = self.n_iter - - shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count) - state.job_count = state.job_count * 2 - state.processing_has_refined_job_count = True - if self.hr_resize_x == 0 and self.hr_resize_y == 0: self.extra_generation_params["Hires upscale"] = self.hr_scale self.hr_upscale_to_x = int(self.width * self.hr_scale) @@ -722,6 +715,22 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f + # special case: the user has chosen to do nothing + if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height: + self.enable_hr = False + self.denoising_strength = None + self.extra_generation_params.pop("Hires upscale", None) + self.extra_generation_params.pop("Hires resize", None) + return + + if not state.processing_has_refined_job_count: + if state.job_count == -1: + state.job_count = self.n_iter + + shared.total_tqdm.updateTotal((self.steps + (self.hr_second_pass_steps or self.steps)) * state.job_count) + state.job_count = state.job_count * 2 + state.processing_has_refined_job_count = True + if self.hr_second_pass_steps: self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps -- cgit v1.2.3 From 066390eb5683945a6e094a817584ada6b1f7118e Mon Sep 17 00:00:00 2001 From: Wes Roberts Date: Wed, 4 Jan 2023 17:58:16 -0500 Subject: Fixes webui.sh to exec LAUNCH_SCRIPT --- webui.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/webui.sh b/webui.sh index 04ecbf76..c4d6521d 100755 --- a/webui.sh +++ b/webui.sh @@ -160,10 +160,10 @@ then printf "\n%s\n" "${delimiter}" printf "Accelerating launch.py..." printf "\n%s\n" "${delimiter}" - accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@" + exec accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@" else printf "\n%s\n" "${delimiter}" printf "Launching launch.py..." 
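A note on the `exec` change above: without `exec`, the shell script stays resident as the parent process, so signals from the terminal or a supervisor hit the wrapper rather than the application, and the wrapper's exit status can mask the application's. A rough Python analogue of the same pattern, offered only as a sketch (the script path is illustrative):

    import os
    import sys

    # Replace the current process image instead of spawning a child,
    # mirroring shell `exec`: signals and the exit status now belong
    # to the launched script itself.
    os.execvp(sys.executable, [sys.executable, "launch.py", *sys.argv[1:]])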
printf "\n%s\n" "${delimiter}" - "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" + exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" fi -- cgit v1.2.3 From 5f4fa942b8ec3ed3b15a352903489d6f9e6eb46e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 5 Jan 2023 02:38:52 +0300 Subject: do not show full window image preview when right mouse button is used --- javascript/imageviewer.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 97f56c07..b7bc2fe1 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -149,7 +149,7 @@ function showGalleryImage() { e.style.cursor='pointer' e.style.userSelect='none' e.addEventListener('mousedown', function (evt) { - if(!opts.js_modal_lightbox) return; + if(!opts.js_modal_lightbox || evt.button != 0) return; modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed) showModal(evt) }, true); -- cgit v1.2.3 From 2e30997450835ed8f80ab5e8f02f7d4c7f26dd3f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 5 Jan 2023 10:21:17 +0300 Subject: move sd_model assignment to the place where we change the sd_model --- modules/processing.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index a12bd9e8..61e97077 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -466,12 +466,16 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try: for k, v in p.override_settings.items(): setattr(opts, k, v) - if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet - if k == 'sd_model_checkpoint': sd_models.reload_model_weights() # make onchange call for changing SD model - if k == 'sd_vae': sd_vae.reload_vae_weights() # make onchange call for changing VAE + if k == 'sd_hypernetwork': + shared.reload_hypernetworks() # make onchange call for changing hypernet + + if k == 'sd_model_checkpoint': + sd_models.reload_model_weights() # make onchange call for changing SD model + p.sd_model = shared.sd_model + + if k == 'sd_vae': + sd_vae.reload_vae_weights() # make onchange call for changing VAE - # Assign sd_model here to ensure that it reflects the model after any changes - p.sd_model = shared.sd_model res = process_images_inner(p) finally: -- cgit v1.2.3 From c3109fa18a5a105eea5e343875b540939884f304 Mon Sep 17 00:00:00 2001 From: me <25877290+Kryptortio@users.noreply.github.com> Date: Thu, 5 Jan 2023 08:27:09 +0100 Subject: Adjusted prefix from i2i/t2i to txt2img and img2img and removed those prefixes from img exclusive scripts --- scripts/custom_code.py | 2 +- scripts/img2imgalt.py | 2 +- scripts/loopback.py | 2 +- scripts/outpainting_mk_2.py | 2 +- scripts/poor_mans_outpainting.py | 2 +- scripts/prompt_matrix.py | 2 +- scripts/prompts_from_file.py | 2 +- scripts/sd_upscale.py | 2 +- scripts/xy_grid.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/scripts/custom_code.py b/scripts/custom_code.py index 841fed97..b3bbee03 100644 --- a/scripts/custom_code.py +++ b/scripts/custom_code.py @@ -14,7 +14,7 @@ class Script(scripts.Script): return cmd_opts.allow_code def ui(self, is_img2img): - elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_custom_code_' + elem_prefix = ('img2img' if is_img2img else 'txt2txt') + '_script_custom_code_' code = gr.Textbox(label="Python code", lines=1, elem_id=elem_prefix + "code") diff --git a/scripts/img2imgalt.py 
b/scripts/img2imgalt.py index cddd46e7..c062dd24 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -126,7 +126,7 @@ class Script(scripts.Script): return is_img2img def ui(self, is_img2img): - elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_i2i_alternative_test_' + elem_prefix = 'script_i2i_alternative_test_' info = gr.Markdown(''' * `CFG Scale` should be 2 or lower. diff --git a/scripts/loopback.py b/scripts/loopback.py index 5c1265a0..93eda1eb 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -17,7 +17,7 @@ class Script(scripts.Script): return is_img2img def ui(self, is_img2img): - elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_loopback_' + elem_prefix = 'script_loopback_' loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=elem_prefix + "loops") denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, elem_id=elem_prefix + "denoising_strength_change_factor") diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index 760cce64..c37bc238 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -129,7 +129,7 @@ class Script(scripts.Script): if not is_img2img: return None - elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_outpainting_mk_2_' + elem_prefix = 'script_outpainting_mk_2_' info = gr.HTML("
<p style=\"margin-bottom:0.75em\">Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8</p>
") diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index 6bcdcc02..784ee422 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -21,7 +21,7 @@ class Script(scripts.Script): if not is_img2img: return None - elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_poor_mans_outpainting_' + elem_prefix = 'script_poor_mans_outpainting_' pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=elem_prefix + "pixels") mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=elem_prefix + "mask_blur") diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index 59172315..f610c334 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -45,7 +45,7 @@ class Script(scripts.Script): return "Prompt matrix" def ui(self, is_img2img): - elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_prompt_matrix_' + elem_prefix = ('img2img' if is_img2img else 'txt2txt') + '_script_prompt_matrix_' put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=elem_prefix + "put_at_start") different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=elem_prefix + "different_seeds") diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index fc8ddd8a..c6a0b709 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -112,7 +112,7 @@ class Script(scripts.Script): return "Prompts from file or textbox" def ui(self, is_img2img): - elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_prompt_from_file_' + elem_prefix = ('img2img' if is_img2img else 'txt2txt') + '_script_prompt_from_file_' checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=elem_prefix + "checkbox_iterate") checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=elem_prefix + "checkbox_iterate_batch") diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py index 9f483a67..2aeeb106 100644 --- a/scripts/sd_upscale.py +++ b/scripts/sd_upscale.py @@ -17,7 +17,7 @@ class Script(scripts.Script): return is_img2img def ui(self, is_img2img): - elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_sd_upscale_' + elem_prefix = 'script_sd_upscale_' info = gr.HTML("
<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>
") overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=elem_prefix + "overlap") diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 90226ccd..8c9cfb9b 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -292,7 +292,7 @@ class Script(scripts.Script): def ui(self, is_img2img): current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img] - elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_xy_grid_' + elem_prefix = ('img2img' if is_img2img else 'txt2txt') + '_script_xy_grid_' with gr.Row(): x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=elem_prefix + "x_type") -- cgit v1.2.3 From 42fcc79bd31e5e5485f1cf115ad505cc623d0ac9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 5 Jan 2023 10:43:21 +0300 Subject: add Discard penultimate sigma to infotext --- modules/sd_samplers.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 31b255a3..01221b89 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -463,8 +463,12 @@ class KDiffusionSampler: return extra_params_kwargs def get_sigmas(self, p, steps): - disc = opts.always_discard_next_to_last_sigma or (self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)) - steps += 1 if disc else 0 + discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) + if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma: + discard_next_to_last_sigma = True + p.extra_generation_params["Discard penultimate sigma"] = True + + steps += 1 if discard_next_to_last_sigma else 0 if p.sampler_noise_scheduler_override: sigmas = p.sampler_noise_scheduler_override(steps) @@ -475,7 +479,7 @@ class KDiffusionSampler: else: sigmas = self.model_wrap.get_sigmas(steps) - if disc: + if discard_next_to_last_sigma: sigmas = torch.cat([sigmas[:-2], sigmas[-1:]]) return sigmas -- cgit v1.2.3 From f185baeb28f348e4ec97cd7070ed219b5f74a48e Mon Sep 17 00:00:00 2001 From: me <25877290+Kryptortio@users.noreply.github.com> Date: Thu, 5 Jan 2023 09:29:07 +0100 Subject: Refactor elem_prefix as function elem_id --- scripts/custom_code.py | 9 ++++++--- scripts/img2imgalt.py | 30 +++++++++++++++++------------- scripts/loopback.py | 15 ++++++++++----- scripts/outpainting_mk_2.py | 18 +++++++++++------- scripts/poor_mans_outpainting.py | 17 ++++++++++------- scripts/prompt_matrix.py | 14 +++++++++----- scripts/prompts_from_file.py | 18 +++++++++++------- scripts/sd_upscale.py | 16 ++++++++++------ scripts/xy_grid.py | 20 ++++++++++++-------- 9 files changed, 96 insertions(+), 61 deletions(-) diff --git a/scripts/custom_code.py b/scripts/custom_code.py index b3bbee03..9ce1f650 100644 --- a/scripts/custom_code.py +++ b/scripts/custom_code.py @@ -3,20 +3,23 @@ import gradio as gr from modules.processing import Processed from modules.shared import opts, cmd_opts, state +import re class Script(scripts.Script): def title(self): return "Custom code" + def elem_id(self, item_id): + gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id + gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) + return gen_elem_id def show(self, is_img2img): return cmd_opts.allow_code def ui(self, is_img2img): - 
elem_prefix = ('img2img' if is_img2img else 'txt2txt') + '_script_custom_code_' - - code = gr.Textbox(label="Python code", lines=1, elem_id=elem_prefix + "code") + code = gr.Textbox(label="Python code", lines=1, elem_id=self.elem_id("code")) return [code] diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py index c062dd24..7555e874 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -16,6 +16,7 @@ import k_diffusion as K from PIL import Image from torch import autocast from einops import rearrange, repeat +import re def find_noise_for_image(p, cond, uncond, cfg_scale, steps): @@ -122,30 +123,33 @@ class Script(scripts.Script): def title(self): return "img2img alternative test" + def elem_id(self, item_id): + gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id + gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) + return gen_elem_id + def show(self, is_img2img): return is_img2img - def ui(self, is_img2img): - elem_prefix = 'script_i2i_alternative_test_' - + def ui(self, is_img2img): info = gr.Markdown(''' * `CFG Scale` should be 2 or lower. ''') - override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler?(this method is built for it)", value=True, elem_id=elem_prefix + "override_sampler") + override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler?(this method is built for it)", value=True, elem_id=self.elem_id("override_sampler")) - override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`?(and `negative prompt`)", value=True, elem_id=elem_prefix + "override_prompt") - original_prompt = gr.Textbox(label="Original prompt", lines=1, elem_id=elem_prefix + "original_prompt") - original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1, elem_id=elem_prefix + "original_negative_prompt") + override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`?(and `negative prompt`)", value=True, elem_id=self.elem_id("override_prompt")) + original_prompt = gr.Textbox(label="Original prompt", lines=1, elem_id=self.elem_id("original_prompt")) + original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1, elem_id=self.elem_id("original_negative_prompt")) - override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True, elem_id=elem_prefix + "override_steps") - st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50, elem_id=elem_prefix + "st") + override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True, elem_id=self.elem_id("override_steps")) + st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50, elem_id=self.elem_id("st")) - override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True, elem_id=elem_prefix + "override_strength") + override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True, elem_id=self.elem_id("override_strength")) - cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0, elem_id=elem_prefix + "cfg") - randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0, elem_id=elem_prefix + "randomness") - sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=elem_prefix + "sigma_adjustment") + cfg = gr.Slider(label="Decode CFG scale", 
minimum=0.0, maximum=15.0, step=0.1, value=1.0, elem_id=self.elem_id("cfg")) + randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0, elem_id=self.elem_id("randomness")) + sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=self.elem_id("sigma_adjustment")) return [ info, diff --git a/scripts/loopback.py b/scripts/loopback.py index 93eda1eb..4df7b73f 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -8,19 +8,24 @@ from modules import processing, shared, sd_samplers, images from modules.processing import Processed from modules.sd_samplers import samplers from modules.shared import opts, cmd_opts, state +import re + class Script(scripts.Script): def title(self): return "Loopback" + def elem_id(self, item_id): + gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id + gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) + return gen_elem_id + def show(self, is_img2img): return is_img2img - def ui(self, is_img2img): - elem_prefix = 'script_loopback_' - - loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=elem_prefix + "loops") - denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, elem_id=elem_prefix + "denoising_strength_change_factor") + def ui(self, is_img2img): + loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops")) + denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, elem_id=self.elem_id("denoising_strength_change_factor")) return [loops, denoising_strength_change_factor] diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index c37bc238..b4a0dc73 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -10,6 +10,7 @@ from PIL import Image, ImageDraw from modules import images, processing, devices from modules.processing import Processed, process_images from modules.shared import opts, cmd_opts, state +import re # this function is taken from https://github.com/parlance-zz/g-diffuser-bot @@ -122,6 +123,11 @@ class Script(scripts.Script): def title(self): return "Outpainting mk2" + def elem_id(self, item_id): + gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id + gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) + return gen_elem_id + def show(self, is_img2img): return is_img2img @@ -129,15 +135,13 @@ class Script(scripts.Script): if not is_img2img: return None - elem_prefix = 'script_outpainting_mk_2_' - info = gr.HTML("
<p style=\"margin-bottom:0.75em\">Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8</p>
") - pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=elem_prefix + "pixels") - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=elem_prefix + "mask_blur") - direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=elem_prefix + "direction") - noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=elem_prefix + "noise_q") - color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=elem_prefix + "color_variation") + pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) + direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) + noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) + color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) return [info, pixels, mask_blur, direction, noise_q, color_variation] diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index 784ee422..1c7dc467 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -7,26 +7,29 @@ from PIL import Image, ImageDraw from modules import images, processing, devices from modules.processing import Processed, process_images from modules.shared import opts, cmd_opts, state - +import re class Script(scripts.Script): def title(self): return "Poor man's outpainting" + def elem_id(self, item_id): + gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id + gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) + return gen_elem_id + def show(self, is_img2img): return is_img2img def ui(self, is_img2img): if not is_img2img: return None - - elem_prefix = 'script_poor_mans_outpainting_' - pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=elem_prefix + "pixels") - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=elem_prefix + "mask_blur") - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=elem_prefix + "inpainting_fill") - direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=elem_prefix + "direction") + pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) + direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], 
elem_id=self.elem_id("direction")) return [pixels, mask_blur, inpainting_fill, direction] diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index f610c334..278d2e68 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -10,6 +10,7 @@ from modules import images from modules.processing import process_images, Processed from modules.shared import opts, cmd_opts, state import modules.sd_samplers +import re def draw_xy_grid(xs, ys, x_label, y_label, cell): @@ -44,11 +45,14 @@ class Script(scripts.Script): def title(self): return "Prompt matrix" - def ui(self, is_img2img): - elem_prefix = ('img2img' if is_img2img else 'txt2txt') + '_script_prompt_matrix_' - - put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=elem_prefix + "put_at_start") - different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=elem_prefix + "different_seeds") + def elem_id(self, item_id): + gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id + gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) + return gen_elem_id + + def ui(self, is_img2img): + put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start")) + different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds")) return [put_at_start, different_seeds] diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index c6a0b709..5c84c3e9 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -13,6 +13,7 @@ from modules import sd_samplers from modules.processing import Processed, process_images from PIL import Image from modules.shared import opts, cmd_opts, state +import re def process_string_tag(tag): @@ -111,14 +112,17 @@ class Script(scripts.Script): def title(self): return "Prompts from file or textbox" - def ui(self, is_img2img): - elem_prefix = ('img2img' if is_img2img else 'txt2txt') + '_script_prompt_from_file_' - - checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=elem_prefix + "checkbox_iterate") - checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=elem_prefix + "checkbox_iterate_batch") + def elem_id(self, item_id): + gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id + gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) + return gen_elem_id - prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=elem_prefix + "prompt_txt") - file = gr.File(label="Upload prompt inputs", type='bytes', elem_id=elem_prefix + "file") + def ui(self, is_img2img): + checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate")) + checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) + + prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt")) + file = gr.File(label="Upload prompt inputs", type='bytes', elem_id=self.elem_id("file")) file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt]) diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py index 2aeeb106..247e755b 100644 --- a/scripts/sd_upscale.py +++ b/scripts/sd_upscale.py 
@@ -7,22 +7,26 @@ from PIL import Image from modules import processing, shared, sd_samplers, images, devices from modules.processing import Processed from modules.shared import opts, cmd_opts, state +import re class Script(scripts.Script): def title(self): return "SD upscale" + def elem_id(self, item_id): + gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id + gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) + return gen_elem_id + def show(self, is_img2img): return is_img2img - def ui(self, is_img2img): - elem_prefix = 'script_sd_upscale_' - + def ui(self, is_img2img): info = gr.HTML("
<p style=\"margin-bottom:0.75em\">Will upscale the image by the selected scale factor; use width and height sliders to set tile size</p>
") - overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=elem_prefix + "overlap") - scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=elem_prefix + "scale_factor") - upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", elem_id=elem_prefix + "upscaler_index") + overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap")) + scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor")) + upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", elem_id=self.elem_id("upscaler_index")) return [info, overlap, upscaler_index, scale_factor] diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 8c9cfb9b..b277a439 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -290,21 +290,25 @@ class Script(scripts.Script): def title(self): return "X/Y plot" + def elem_id(self, item_id): + gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id + gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) + return gen_elem_id + def ui(self, is_img2img): current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img] - elem_prefix = ('img2img' if is_img2img else 'txt2txt') + '_script_xy_grid_' with gr.Row(): - x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=elem_prefix + "x_type") - x_values = gr.Textbox(label="X values", lines=1, elem_id=elem_prefix + "x_values") + x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) + x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) with gr.Row(): - y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=elem_prefix + "y_type") - y_values = gr.Textbox(label="Y values", lines=1, elem_id=elem_prefix + "y_values") + y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) + y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) - draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=elem_prefix + "draw_legend") - include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=elem_prefix + "include_lone_images") - no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=elem_prefix + "no_fixed_seeds") + draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) + include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) + no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] -- cgit v1.2.3 From 997461d3dd86f51c06ea0c2eff17ce8b8b48c0af Mon Sep 17 00:00:00 2001 From: AUTOMATIC 
<16777216c@gmail.com> Date: Thu, 5 Jan 2023 11:57:01 +0300 Subject: add footer with versions --- html/footer.html | 4 ++++ launch.py | 20 ++++++++++++++++---- modules/ui.py | 31 ++++++++++++++++++++++++++++++- style.css | 5 +++++ 4 files changed, 55 insertions(+), 5 deletions(-) diff --git a/html/footer.html b/html/footer.html index a8f2adf7..bad87ff6 100644 --- a/html/footer.html +++ b/html/footer.html @@ -7,3 +7,7 @@  •  Reload UI +
+<div class="versions">
+{versions} +
diff --git a/launch.py b/launch.py index af0d418b..49b91b1f 100644 --- a/launch.py +++ b/launch.py @@ -13,6 +13,21 @@ dir_extensions = "extensions" python = sys.executable git = os.environ.get('GIT', "git") index_url = os.environ.get('INDEX_URL', "") +stored_commit_hash = None + + +def commit_hash(): + global stored_commit_hash + + if stored_commit_hash is not None: + return stored_commit_hash + + try: + stored_commit_hash = run(f"{git} rev-parse HEAD").strip() + except Exception: + stored_commit_hash = "" + + return stored_commit_hash def extract_arg(args, name): @@ -194,10 +209,7 @@ def prepare_environment(): xformers = '--xformers' in sys.argv ngrok = '--ngrok' in sys.argv - try: - commit = run(f"{git} rev-parse HEAD").strip() - except Exception: - commit = "" + commit = commit_hash() print(f"Python {sys.version}") print(f"Commit hash: {commit}") diff --git a/modules/ui.py b/modules/ui.py index bb64fe20..81d96c5b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1696,7 +1696,9 @@ def create_ui(): if os.path.exists("html/footer.html"): with open("html/footer.html", encoding="utf8") as file: - gr.HTML(file.read(), elem_id="footer") + footer = file.read() + footer = footer.format(versions=versions_html()) + gr.HTML(footer, elem_id="footer") text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) settings_submit.click( @@ -1857,3 +1859,30 @@ def reload_javascript(): if not hasattr(shared, 'GradioTemplateResponseOriginal'): shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse + + +def versions_html(): + import torch + import launch + + python_version = ".".join([str(x) for x in sys.version_info[0:3]]) + commit = launch.commit_hash() + short_commit = commit[0:8] + + if shared.xformers_available: + import xformers + xformers_version = xformers.__version__ + else: + xformers_version = "N/A" + + return f""" +python: {python_version} + •  +torch: {torch.__version__} + •  +xformers: {xformers_version} + •  +gradio: {gr.__version__} + •  +commit: {short_commit} +""" diff --git a/style.css b/style.css index 09ee540b..ee74d79e 100644 --- a/style.css +++ b/style.css @@ -628,6 +628,11 @@ footer { display: inline-block; } +#footer .versions{ + font-size: 85%; + opacity: 0.85; +} + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. 
If you change anything above, you need to make sure it is RTL compliant by just running -- cgit v1.2.3 From f8d0cf6a6ec4911559cfecb9a9d1d46b547b38e8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 5 Jan 2023 12:08:11 +0300 Subject: rework #6329 to remove duplicate code and add prevent tab names for showing in ids for scripts that only exist on one tab --- modules/scripts.py | 10 ++++++++++ scripts/custom_code.py | 6 ------ scripts/img2imgalt.py | 6 ------ scripts/loopback.py | 6 ------ scripts/outpainting_mk_2.py | 6 ------ scripts/poor_mans_outpainting.py | 6 ------ scripts/prompt_matrix.py | 6 ------ scripts/prompts_from_file.py | 6 ------ scripts/sd_upscale.py | 6 ------ scripts/xy_grid.py | 5 ----- 10 files changed, 10 insertions(+), 53 deletions(-) diff --git a/modules/scripts.py b/modules/scripts.py index 722f8685..0c44f191 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,4 +1,5 @@ import os +import re import sys import traceback from collections import namedtuple @@ -128,6 +129,15 @@ class Script: """unused""" return "" + def elem_id(self, item_id): + """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id""" + + need_tabname = self.show(True) == self.show(False) + tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else "" + title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower())) + + return f'script_{tabname}{title}_{item_id}' + current_basedir = paths.script_path diff --git a/scripts/custom_code.py b/scripts/custom_code.py index 9ce1f650..d29113e6 100644 --- a/scripts/custom_code.py +++ b/scripts/custom_code.py @@ -3,18 +3,12 @@ import gradio as gr from modules.processing import Processed from modules.shared import opts, cmd_opts, state -import re class Script(scripts.Script): def title(self): return "Custom code" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def show(self, is_img2img): return cmd_opts.allow_code diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py index 7555e874..cbdfc6b3 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -16,7 +16,6 @@ import k_diffusion as K from PIL import Image from torch import autocast from einops import rearrange, repeat -import re def find_noise_for_image(p, cond, uncond, cfg_scale, steps): @@ -123,11 +122,6 @@ class Script(scripts.Script): def title(self): return "img2img alternative test" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def show(self, is_img2img): return is_img2img diff --git a/scripts/loopback.py b/scripts/loopback.py index 4df7b73f..1dab9476 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -8,18 +8,12 @@ from modules import processing, shared, sd_samplers, images from modules.processing import Processed from modules.sd_samplers import samplers from modules.shared import opts, cmd_opts, state -import re class Script(scripts.Script): def title(self): return "Loopback" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = 
re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def show(self, is_img2img): return is_img2img diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index b4a0dc73..0906da6a 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -10,7 +10,6 @@ from PIL import Image, ImageDraw from modules import images, processing, devices from modules.processing import Processed, process_images from modules.shared import opts, cmd_opts, state -import re # this function is taken from https://github.com/parlance-zz/g-diffuser-bot @@ -123,11 +122,6 @@ class Script(scripts.Script): def title(self): return "Outpainting mk2" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def show(self, is_img2img): return is_img2img diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index 1c7dc467..d8feda00 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -7,18 +7,12 @@ from PIL import Image, ImageDraw from modules import images, processing, devices from modules.processing import Processed, process_images from modules.shared import opts, cmd_opts, state -import re class Script(scripts.Script): def title(self): return "Poor man's outpainting" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def show(self, is_img2img): return is_img2img diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index 278d2e68..dd95e588 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -10,7 +10,6 @@ from modules import images from modules.processing import process_images, Processed from modules.shared import opts, cmd_opts, state import modules.sd_samplers -import re def draw_xy_grid(xs, ys, x_label, y_label, cell): @@ -45,11 +44,6 @@ class Script(scripts.Script): def title(self): return "Prompt matrix" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def ui(self, is_img2img): put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start")) different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds")) diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 5c84c3e9..2751f98a 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -13,7 +13,6 @@ from modules import sd_samplers from modules.processing import Processed, process_images from PIL import Image from modules.shared import opts, cmd_opts, state -import re def process_string_tag(tag): @@ -112,11 +111,6 @@ class Script(scripts.Script): def title(self): return "Prompts from file or textbox" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def ui(self, is_img2img): checkbox_iterate = gr.Checkbox(label="Iterate seed 
every line", value=False, elem_id=self.elem_id("checkbox_iterate")) checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py index 247e755b..9b8ffd85 100644 --- a/scripts/sd_upscale.py +++ b/scripts/sd_upscale.py @@ -7,18 +7,12 @@ from PIL import Image from modules import processing, shared, sd_samplers, images, devices from modules.processing import Processed from modules.shared import opts, cmd_opts, state -import re class Script(scripts.Script): def title(self): return "SD upscale" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def show(self, is_img2img): return is_img2img diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index b277a439..f04d9b7e 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -290,11 +290,6 @@ class Script(scripts.Script): def title(self): return "X/Y plot" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def ui(self, is_img2img): current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img] -- cgit v1.2.3 From eea8fc40e16664ddc8a9aec77206da704a35dde0 Mon Sep 17 00:00:00 2001 From: timntorres Date: Thu, 5 Jan 2023 07:24:22 -0800 Subject: Add option to save ti settings to file. --- modules/shared.py | 1 + modules/textual_inversion/textual_inversion.py | 30 +++++++++++++++++++++++--- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index e0f44c6d..933cd738 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -362,6 +362,7 @@ options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. 
Training of embedding or HN can be resumed with the matching optim file."), + "save_train_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file when training starts."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 71e07bcc..2bed2ecb 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -1,6 +1,7 @@ import os import sys import traceback +import inspect import torch import tqdm @@ -229,6 +230,28 @@ def write_loss(log_directory, filename, step, epoch_len, values): **values, }) +def save_settings_to_file(initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + checkpoint = sd_models.select_checkpoint() + model_name = checkpoint.model_name + model_hash = '[{}]'.format(checkpoint.hash) + + # Get a list of the argument names. + arg_names = inspect.getfullargspec(save_settings_to_file).args + + # Create a list of the argument names to include in the settings string. + names = arg_names[:16] # Include all arguments up until the preview-related ones. + if preview_from_txt2img: + names.extend(arg_names[16:]) # Include all remaining arguments if `preview_from_txt2img` is True. + + # Build the settings string. + settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" + for name in names: + value = locals()[name] + settings_str += f"{name}: {value}\n" + + with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: + fout.write(settings_str + "\n\n") + def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): assert model_name, f"{name} not selected" assert learn_rate, "Learning rate is empty or 0" @@ -292,13 +315,13 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ if initial_step >= steps: shared.state.textinfo = "Model has already been trained beyond specified max steps" return embedding, filename + scheduler = LearnRateScheduler(learn_rate, steps, initial_step) - clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \ None if clip_grad: - clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False) + clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False) # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." 
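The `save_settings_to_file` helper above leans on two introspection facts: `inspect.getfullargspec` returns parameter names in declaration order, and `locals()` called inside the function maps each name to the value the call received. A trimmed-down sketch of the same pattern, with illustrative parameter names:

    import datetime
    import inspect

    def save_settings_to_file(learn_rate, batch_size, log_directory,
                              preview_from_txt2img=False, preview_prompt=""):
        arg_names = inspect.getfullargspec(save_settings_to_file).args
        names = arg_names[:3]            # settings that are always logged
        if preview_from_txt2img:
            names.extend(arg_names[3:])  # preview settings only when relevant

        settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n"
        for name in names:
            # locals() maps each parameter name to its received value.
            settings_str += f"{name}: {locals()[name]}\n"
        return settings_str

The trade-off is that slicing by a positional border index silently breaks if the signature is reordered, which is why the follow-up commits pin that index explicitly.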
old_parallel_processing_allowed = shared.parallel_processing_allowed @@ -306,7 +329,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ pin_memory = shared.opts.pin_memory ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) - + if shared.opts.save_train_settings_to_txt: + save_settings_to_file(initial_step , len(ds) , embedding_name, len(embedding.vec) , learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) latent_sampling_method = ds.latent_sampling_method dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) -- cgit v1.2.3 From 19a81ac2871ec900fc8b7955bbc2554b6c5ac6b1 Mon Sep 17 00:00:00 2001 From: cat Date: Thu, 5 Jan 2023 20:17:39 +0500 Subject: hires-fix: add "nearest-exact" latent upscale mode. --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/shared.py b/modules/shared.py index e0f44c6d..b7a3ce5c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -576,6 +576,7 @@ latent_upscale_modes = { "Latent (bicubic)": {"mode": "bicubic", "antialias": False}, "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True}, "Latent (nearest)": {"mode": "nearest", "antialias": False}, + "Latent (nearest-exact)": {"mode": "nearest-exact", "antialias": False}, } sd_upscalers = [] -- cgit v1.2.3 From b85c2b5cf4a6809bc871718cf4680d49c3e95e94 Mon Sep 17 00:00:00 2001 From: timntorres Date: Thu, 5 Jan 2023 08:14:38 -0800 Subject: Clean up ti, add same behavior to hypernetwork. --- modules/hypernetworks/hypernetwork.py | 31 +++++++++++++++++++++++++- modules/shared.py | 2 +- modules/textual_inversion/textual_inversion.py | 14 +++++++----- 3 files changed, 40 insertions(+), 7 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 6a9b1398..d5985263 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -401,7 +401,33 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, hypernet.save(fn) shared.reload_hypernetworks() +# Note: textual_inversion.py has a nearly identical function of the same name. 
+def save_settings_to_file(initial_step, num_of_dataset_images, hypernetwork_name, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + checkpoint = sd_models.select_checkpoint() + model_name = checkpoint.model_name + model_hash = '[{}]'.format(checkpoint.hash) + # Starting index of preview-related arguments. + border_index = 19 + + # Get a list of the argument names, excluding default argument. + sig = inspect.signature(save_settings_to_file) + arg_names = [p.name for p in sig.parameters.values() if p.default == p.empty] + + # Create a list of the argument names to include in the settings string. + names = arg_names[:border_index] # Include all arguments up until the preview-related ones. + + # Include preview-related arguments if applicable. + if preview_from_txt2img: + names.extend(arg_names[border_index:]) + + # Build the settings string. + settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" + for name in names: + value = locals()[name] + settings_str += f"{name}: {value}\n" + with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: + fout.write(settings_str + "\n\n") def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. 
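The hypernetwork copy filters the argument list with `inspect.signature` instead, keeping only parameters that declare no default value. A minimal demonstration of that predicate (the function name is hypothetical):

    import inspect

    def train(name, steps, preview_from_txt2img=False):
        pass

    sig = inspect.signature(train)
    required = [p.name for p in sig.parameters.values() if p.default is p.empty]
    # required == ['name', 'steps'] -- parameters without a default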
@@ -457,7 +483,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, pin_memory = shared.opts.pin_memory ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) - + + if shared.opts.save_training_settings_to_txt: + save_settings_to_file(initial_step, len(ds), hypernetwork_name, hypernetwork.layer_structure, hypernetwork.activation_func, hypernetwork.weight_init, hypernetwork.add_layer_norm, hypernetwork.use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + latent_sampling_method = ds.latent_sampling_method dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) diff --git a/modules/shared.py b/modules/shared.py index 933cd738..10231a75 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -362,7 +362,7 @@ options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."), - "save_train_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file when training starts."), + "save_training_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file whenever training starts."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 2bed2ecb..68648550 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -230,18 +230,20 @@ def write_loss(log_directory, filename, step, epoch_len, values): **values, }) +# Note: hypernetwork.py has a nearly identical function of the same name. 
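# (These two near-duplicate save_settings_to_file implementations are
# consolidated later in this series: commit 683287d "rework saving training
# params to file" replaces both with modules/textual_inversion/logging.py.)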
def save_settings_to_file(initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): checkpoint = sd_models.select_checkpoint() model_name = checkpoint.model_name model_hash = '[{}]'.format(checkpoint.hash) - + # Starting index of preview-related arguments. + border_index = 16 # Get a list of the argument names. arg_names = inspect.getfullargspec(save_settings_to_file).args # Create a list of the argument names to include in the settings string. - names = arg_names[:16] # Include all arguments up until the preview-related ones. + names = arg_names[:border_index] # Include all arguments up until the preview-related ones. if preview_from_txt2img: - names.extend(arg_names[16:]) # Include all remaining arguments if `preview_from_txt2img` is True. + names.extend(arg_names[border_index:]) # Include all remaining arguments if `preview_from_txt2img` is True. # Build the settings string. settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" @@ -329,8 +331,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ pin_memory = shared.opts.pin_memory ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) - if shared.opts.save_train_settings_to_txt: - save_settings_to_file(initial_step , len(ds) , embedding_name, len(embedding.vec) , learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + + if shared.opts.save_training_settings_to_txt: + save_settings_to_file(initial_step, len(ds), embedding_name, len(embedding.vec), learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + latent_sampling_method = ds.latent_sampling_method dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory) -- cgit v1.2.3 From b6bab2f052b32c0ffebe6aecc1819ccf20cf8c5d Mon Sep 17 00:00:00 2001 From: timntorres Date: Thu, 5 Jan 2023 09:14:56 -0800 Subject: Include model in log file. Exclude directory. 
--- modules/hypernetworks/hypernetwork.py | 28 +++++++++----------------- modules/textual_inversion/textual_inversion.py | 22 +++++++++----------- 2 files changed, 19 insertions(+), 31 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index d5985263..3237c37a 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -402,30 +402,22 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() # Note: textual_inversion.py has a nearly identical function of the same name. -def save_settings_to_file(initial_step, num_of_dataset_images, hypernetwork_name, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - checkpoint = sd_models.select_checkpoint() - model_name = checkpoint.model_name - model_hash = '[{}]'.format(checkpoint.hash) +def save_settings_to_file(model_name, model_hash, initial_step, num_of_dataset_images, hypernetwork_name, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # Starting index of preview-related arguments. - border_index = 19 - - # Get a list of the argument names, excluding default argument. - sig = inspect.signature(save_settings_to_file) - arg_names = [p.name for p in sig.parameters.values() if p.default == p.empty] - + border_index = 21 + # Get a list of the argument names. + arg_names = inspect.getfullargspec(save_settings_to_file).args # Create a list of the argument names to include in the settings string. names = arg_names[:border_index] # Include all arguments up until the preview-related ones. - - # Include preview-related arguments if applicable. if preview_from_txt2img: - names.extend(arg_names[border_index:]) - + names.extend(arg_names[border_index:]) # Include preview-related arguments if applicable. # Build the settings string. settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" for name in names: - value = locals()[name] - settings_str += f"{name}: {value}\n" - + if name != 'log_directory': # It's useless and redundant to save log_directory. + value = locals()[name] + settings_str += f"{name}: {value}\n" + # Create or append to the file. 
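# ("a+" creates settings.txt if it is absent and appends otherwise, so repeated
# training runs accumulate their settings blocks in the same file.)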
with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: fout.write(settings_str + "\n\n") @@ -485,7 +477,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) if shared.opts.save_training_settings_to_txt: - save_settings_to_file(initial_step, len(ds), hypernetwork_name, hypernetwork.layer_structure, hypernetwork.activation_func, hypernetwork.weight_init, hypernetwork.add_layer_norm, hypernetwork.use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + save_settings_to_file(checkpoint.model_name, '[{}]'.format(checkpoint.hash), initial_step, len(ds), hypernetwork_name, hypernetwork.layer_structure, hypernetwork.activation_func, hypernetwork.weight_init, hypernetwork.add_layer_norm, hypernetwork.use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) latent_sampling_method = ds.latent_sampling_method diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 68648550..ce7e4f5d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -231,26 +231,22 @@ def write_loss(log_directory, filename, step, epoch_len, values): }) # Note: hypernetwork.py has a nearly identical function of the same name. -def save_settings_to_file(initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - checkpoint = sd_models.select_checkpoint() - model_name = checkpoint.model_name - model_hash = '[{}]'.format(checkpoint.hash) +def save_settings_to_file(model_name, model_hash, initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # Starting index of preview-related arguments. - border_index = 16 + border_index = 18 # Get a list of the argument names. 
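# border_index moves from 16 to 18 in this hunk (and 19 to 21 in the
# hypernetwork hunk above) because two parameters, model_name and model_hash,
# were prepended to the signature, shifting the start of the preview-related
# arguments by two positions.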
- arg_names = inspect.getfullargspec(save_settings_to_file).args - + arg_names = inspect.getfullargspec(save_settings_to_file).args # Create a list of the argument names to include in the settings string. names = arg_names[:border_index] # Include all arguments up until the preview-related ones. if preview_from_txt2img: - names.extend(arg_names[border_index:]) # Include all remaining arguments if `preview_from_txt2img` is True. - + names.extend(arg_names[border_index:]) # Include preview-related arguments if applicable. # Build the settings string. settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" for name in names: - value = locals()[name] - settings_str += f"{name}: {value}\n" - + if name != 'log_directory': # It's useless and redundant to save log_directory. + value = locals()[name] + settings_str += f"{name}: {value}\n" + # Create or append to the file. with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: fout.write(settings_str + "\n\n") @@ -333,7 +329,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) if shared.opts.save_training_settings_to_txt: - save_settings_to_file(initial_step, len(ds), embedding_name, len(embedding.vec), learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + save_settings_to_file(checkpoint.model_name, '[{}]'.format(checkpoint.hash), initial_step, len(ds), embedding_name, len(embedding.vec), learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) latent_sampling_method = ds.latent_sampling_method -- cgit v1.2.3 From fda04e620d529031e2134520e74756d0efa30464 Mon Sep 17 00:00:00 2001 From: Kuma <36082288+KumiIT@users.noreply.github.com> Date: Thu, 5 Jan 2023 18:44:19 +0100 Subject: typo in TI --- modules/textual_inversion/textual_inversion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 71e07bcc..24b43045 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -298,7 +298,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \ None if clip_grad: - clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False) + clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False) # dataset loading may 
take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." old_parallel_processing_allowed = shared.parallel_processing_allowed -- cgit v1.2.3 From 847f869c67c7108e3e792fc193331d0e6acca29c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 5 Jan 2023 21:00:52 +0300 Subject: experimental optimization --- modules/processing.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 61e97077..a408d622 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -544,6 +544,29 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: infotexts = [] output_images = [] + cached_uc = [None, None] + cached_c = [None, None] + + def get_conds_with_caching(function, required_prompts, steps, cache): + """ + Returns the result of calling function(shared.sd_model, required_prompts, steps) + using a cache to store the result if the same arguments have been used before. + + cache is an array containing two elements. The first element is a tuple + representing the previously used arguments, or None if no arguments + have been used before. The second element is where the previously + computed result is stored. + """ + + if cache[0] is not None and (required_prompts, steps) == cache[0]: + return cache[1] + + with devices.autocast(): + cache[1] = function(shared.sd_model, required_prompts, steps) + + cache[0] = (required_prompts, steps) + return cache[1] + with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) @@ -571,9 +594,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds) - with devices.autocast(): - uc = prompt_parser.get_learned_conditioning(shared.sd_model, negative_prompts, p.steps) - c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) + uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc) + c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c) if len(model_hijack.comments) > 0: for comment in model_hijack.comments: -- cgit v1.2.3 From 81133d4168ae0bae9bf8bf1a1d4983319a589112 Mon Sep 17 00:00:00 2001 From: Faber Date: Fri, 6 Jan 2023 03:38:37 +0700 Subject: allow loading embeddings from subdirectories --- modules/textual_inversion/textual_inversion.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 24b43045..0a059044 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -149,19 +149,20 @@ class EmbeddingDatabase: else: self.skipped_embeddings[name] = embedding - for fn in os.listdir(self.embeddings_dir): - try: - fullfn = os.path.join(self.embeddings_dir, fn) - - if os.stat(fullfn).st_size == 0: + for root, dirs, fns in os.walk(self.embeddings_dir): + for fn in fns: + try: + fullfn = os.path.join(root, fn) + + if os.stat(fullfn).st_size == 0: + continue + + process_file(fullfn, fn) + except Exception: + print(f"Error loading embedding {fn}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) continue - 
process_file(fullfn, fn) - except Exception: - print(f"Error loading embedding {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - continue - print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") if len(self.skipped_embeddings) > 0: print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") -- cgit v1.2.3 From b5253f0dab529707f1fe2e11211a10ce2f264617 Mon Sep 17 00:00:00 2001 From: noodleanon <122053346+noodleanon@users.noreply.github.com> Date: Thu, 5 Jan 2023 21:21:48 +0000 Subject: allow img2img api to run scripts --- modules/api/api.py | 27 ++++++++++++++++++++++++--- modules/api/models.py | 2 +- modules/processing.py | 4 ++-- 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 2103709b..aa62a42e 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -11,7 +11,7 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.extras import run_extras @@ -28,8 +28,13 @@ def upscaler_to_index(name: str): try: return [x.name.lower() for x in shared.sd_upscalers].index(name.lower()) except: - raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") + raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in sd_upscalers])}") +def script_name_to_index(name, scripts): + try: + return [script.title().lower() for script in scripts].index(name.lower()) + except: + raise HTTPException(status_code=422, detail=f"Script '{name}' not found") def validate_sampler_name(name): config = sd_samplers.all_samplers_map.get(name, None) @@ -170,6 +175,14 @@ class Api: if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") + if img2imgreq.script_name is not None: + if scripts.scripts_img2img.scripts == []: + scripts.scripts_img2img.initialize_scripts(True) + ui.create_ui() + + script_idx = script_name_to_index(img2imgreq.script_name, scripts.scripts_img2img.selectable_scripts) + script = scripts.scripts_img2img.selectable_scripts[script_idx] + mask = img2imgreq.mask if mask: mask = decode_base64_to_image(mask) @@ -186,13 +199,21 @@ class Api: args = vars(populate) args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. 
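# script_name was consumed above to resolve the script object; like
# include_init_images, it has no matching parameter on
# StableDiffusionProcessingImg2Img's constructor, so it is popped here before
# **args is expanded below.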
+ args.pop('script_name', None) with self.queue_lock: p = StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args) p.init_images = [decode_base64_to_image(x) for x in init_images] shared.state.begin() - processed = process_images(p) + if 'script' in locals(): + p.outpath_grids = opts.outdir_img2img_grids + p.outpath_samples = opts.outdir_img2img_samples + p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args + processed = scripts.scripts_img2img.run(p, *p.script_args) + else: + processed = process_images(p) + shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) diff --git a/modules/api/models.py b/modules/api/models.py index d8198a27..862477e7 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -106,7 +106,7 @@ StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingImg2Img", StableDiffusionProcessingImg2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}] + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}] ).generate_model() class TextToImageResponse(BaseModel): diff --git a/modules/processing.py b/modules/processing.py index a408d622..d5ac7eb1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -98,7 +98,7 @@ class StableDiffusionProcessing(): """ The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing """ - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: 
bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): if sampler_index is not None: print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) @@ -149,7 +149,7 @@ class StableDiffusionProcessing(): self.seed_resize_from_w = 0 self.scripts = None - self.script_args = None + self.script_args = script_args self.all_prompts = None self.all_negative_prompts = None self.all_seeds = None -- cgit v1.2.3 From eadd1bf06adbd7263875640a6446d3b0184d1561 Mon Sep 17 00:00:00 2001 From: noodleanon <122053346+noodleanon@users.noreply.github.com> Date: Thu, 5 Jan 2023 21:22:04 +0000 Subject: allow sdupscale to accept upscaler name --- scripts/sd_upscale.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py index 9b8ffd85..332d76d9 100644 --- a/scripts/sd_upscale.py +++ b/scripts/sd_upscale.py @@ -25,6 +25,8 @@ class Script(scripts.Script): return [info, overlap, upscaler_index, scale_factor] def run(self, p, _, overlap, upscaler_index, scale_factor): + if isinstance(upscaler_index, str): + upscaler_index = [x.name.lower() for x in shared.sd_upscalers].index(upscaler_index.lower()) processing.fix_seed(p) upscaler = shared.sd_upscalers[upscaler_index] -- cgit v1.2.3 From 8111b5569d07c7ac3b695e28171aede728b4ae56 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 3 Jan 2023 20:43:05 -0500 Subject: Add support for PyTorch nightly and local builds --- modules/devices.py | 28 +++++++++++++++++++++++----- webui.py | 7 ++++++- 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 800510b7..caeb0276 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -133,8 +133,26 @@ def numpy_fix(self, *args, **kwargs): return orig_tensor_numpy(self, *args, **kwargs) -# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working -if has_mps() and version.parse(torch.__version__) < version.parse("1.13"): - torch.Tensor.to = tensor_to_fix - torch.nn.functional.layer_norm = layer_norm_fix - torch.Tensor.numpy = numpy_fix +# MPS workaround for https://github.com/pytorch/pytorch/issues/89784 +orig_cumsum = torch.cumsum +orig_Tensor_cumsum = torch.Tensor.cumsum +def cumsum_fix(input, cumsum_func, *args, **kwargs): + if input.device.type == 'mps': + output_dtype = kwargs.get('dtype', input.dtype) + if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]): + return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) + return cumsum_func(input, *args, **kwargs) + + +if has_mps(): + if version.parse(torch.__version__) < version.parse("1.13"): + # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working + torch.Tensor.to = tensor_to_fix + torch.nn.functional.layer_norm = layer_norm_fix + torch.Tensor.numpy = numpy_fix + elif version.parse(torch.__version__) > version.parse("1.13.1"): + if not 
torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)): + torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) ) + torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) + orig_narrow = torch.narrow + torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() ) diff --git a/webui.py b/webui.py index 13375e71..ddfaea95 100644 --- a/webui.py +++ b/webui.py @@ -4,7 +4,7 @@ import threading import time import importlib import signal -import threading +import re from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware @@ -13,6 +13,11 @@ from modules import import_hook, errors from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path +import torch +# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors +if ".dev" in torch.__version__ or "+git" in torch.__version__: + torch.__version__ = re.search(r'[\d.]+', torch.__version__).group(0) + from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir import modules.codeformer_model as codeformer import modules.extras -- cgit v1.2.3 From d61a5aa4f623f6630670241aca8fc5c2a6381769 Mon Sep 17 00:00:00 2001 From: acncagua Date: Fri, 6 Jan 2023 10:58:22 +0900 Subject: Add files via upload --- modules/ui.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/ui.py b/modules/ui.py index 81d96c5b..030f0685 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -550,6 +550,8 @@ Requested path was: {f} os.startfile(path) elif platform.system() == "Darwin": sp.Popen(["open", path]) + elif "microsoft-standard-WSL2" in platform.uname().release: + sp.Popen(["wsl-open", path]) else: sp.Popen(["xdg-open", path]) -- cgit v1.2.3 From d782a95967c9eea753df3333cd1954b6ec73eba0 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 27 Dec 2022 08:50:55 -0500 Subject: Add Birch-san's sub-quadratic attention implementation --- README.md | 1 + modules/sd_hijack.py | 15 ++- modules/sd_hijack_optimizations.py | 124 ++++++++++++++++++----- modules/shared.py | 4 + modules/sub_quadratic_attention.py | 201 +++++++++++++++++++++++++++++++++++++ requirements.txt | 2 +- 6 files changed, 312 insertions(+), 35 deletions(-) create mode 100644 modules/sub_quadratic_attention.py diff --git a/README.md b/README.md index 556000fb..1913caf3 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,7 @@ The documentation was moved from this README over to the project's [wiki](https: - Ideas for optimizations - https://github.com/basujindal/stable-diffusion - Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing. - Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion) +- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san), Amin Rezaei (https://github.com/AminRezaei0x443) - Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas). 
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 690a9ec2..019a6f3f 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -7,8 +7,6 @@ from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet -from modules.sd_hijack_optimizations import invokeAI_mps_available - import ldm.modules.attention import ldm.modules.diffusionmodules.model import ldm.modules.diffusionmodules.openaimodel @@ -40,17 +38,16 @@ def apply_optimizations(): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward + elif cmd_opts.opt_sub_quad_attention: + print("Applying sub-quadratic cross attention optimization.") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): - if not invokeAI_mps_available and shared.device.type == 'mps': - print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.") - print("Applying v1 cross attention optimization.") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - else: - print("Applying cross attention optimization (InvokeAI).") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI + print("Applying cross attention optimization (InvokeAI).") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): print("Applying cross attention optimization (Doggettx).") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 02c87f40..f5c153e8 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,7 +1,7 @@ import math import sys import traceback -import importlib +import psutil import torch from torch import einsum @@ -12,6 +12,8 @@ from einops import rearrange from modules import shared from modules.hypernetworks import hypernetwork +from .sub_quadratic_attention import efficient_dot_product_attention + if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: try: @@ -22,6 +24,19 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: print(traceback.format_exc(), file=sys.stderr) +def get_available_vram(): + if shared.device.type == 'cuda': + stats = torch.cuda.memory_stats(shared.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = 
torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + return mem_free_total + else: + return psutil.virtual_memory().available + + # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): h = self.heads @@ -76,12 +91,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch + mem_free_total = get_available_vram() gb = 1024 ** 3 tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() @@ -118,19 +128,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) -def check_for_psutil(): - try: - spec = importlib.util.find_spec('psutil') - return spec is not None - except ModuleNotFoundError: - return False - -invokeAI_mps_available = check_for_psutil() - # -- Taken from https://github.com/invoke-ai/InvokeAI and modified -- -if invokeAI_mps_available: - import psutil - mem_total_gb = psutil.virtual_memory().total // (1 << 30) +mem_total_gb = psutil.virtual_memory().total // (1 << 30) def einsum_op_compvis(q, k, v): s = einsum('b i d, b j d -> b i j', q, k) @@ -215,6 +214,70 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): # -- End of code from https://github.com/invoke-ai/InvokeAI -- + +# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 +def sub_quad_attention_forward(self, x, context=None, mask=None): + assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." 
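# Shape bookkeeping for the reshapes a few lines below, assuming x is
# [batch, tokens, heads * dim_head]:
#   .unflatten(-1, (h, -1))  -> [batch, tokens, h, dim_head]
#   .transpose(1, 2)         -> [batch, h, tokens, dim_head]
#   .flatten(end_dim=1)      -> [batch * h, tokens, dim_head]
# which is the [batch * num_heads, tokens, channels_per_head] layout that
# efficient_dot_product_attention expects; the inverse chain
# .unflatten(0, (-1, h)).transpose(1, 2).flatten(start_dim=2) restores x.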
+ + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + k = self.to_k(context_k) + v = self.to_v(context_v) + del context, context_k, context_v, x + + q = q.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + + x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + + x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) + + out_proj, dropout = self.to_out + x = out_proj(x) + x = dropout(x) + + return x + +def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold_bytes=None, use_checkpoint=True): + bytes_per_token = torch.finfo(q.dtype).bits//8 + batch_x_heads, q_tokens, _ = q.shape + _, k_tokens, _ = k.shape + qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens + + available_vram = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) + + if chunk_threshold_bytes is None: + chunk_threshold_bytes = available_vram + elif chunk_threshold_bytes == 0: + chunk_threshold_bytes = None + + if kv_chunk_size_min is None: + kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])) + elif kv_chunk_size_min == 0: + kv_chunk_size_min = None + + if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: + # the big matmul fits into our memory limit; do everything in 1 chunk, + # i.e. 
send it down the unchunked fast-path + query_chunk_size = q_tokens + kv_chunk_size = k_tokens + + return efficient_dot_product_attention( + q, + k, + v, + query_chunk_size=q_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min = kv_chunk_size_min, + use_checkpoint=use_checkpoint, + ) + + def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) @@ -252,12 +315,7 @@ def cross_attention_attnblock_forward(self, x): h_ = torch.zeros_like(k, device=q.device) - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch + mem_free_total = get_available_vram() tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() mem_required = tensor_size * 2.5 @@ -312,3 +370,19 @@ def xformers_attnblock_forward(self, x): return x + out except NotImplementedError: return cross_attention_attnblock_forward(self, x) + +def sub_quad_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + out = rearrange(out, 'b (h w) c -> b c h w', h=h) + out = self.proj_out(out) + return x + out diff --git a/modules/shared.py b/modules/shared.py index d4ddeea0..487a7792 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -56,6 +56,10 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.") +parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization") +parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024) +parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None) +parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the size threshold in bytes for the sub-quadratic cross-attention layer optimization to use chunking", default=None) parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. 
By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py new file mode 100644 index 00000000..b11dc1c7 --- /dev/null +++ b/modules/sub_quadratic_attention.py @@ -0,0 +1,201 @@ +# original source: +# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py +# license: +# unspecified +# credit: +# Amin Rezaei (original author) +# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks) +# implementation of: +# Self-attention Does Not Need O(n2) Memory": +# https://arxiv.org/abs/2112.05682v2 + +from functools import partial +import torch +from torch import Tensor +from torch.utils.checkpoint import checkpoint +import math +from typing import Optional, NamedTuple, Protocol, List + +def dynamic_slice( + x: Tensor, + starts: List[int], + sizes: List[int], +) -> Tensor: + slicing = [slice(start, start + size) for start, size in zip(starts, sizes)] + return x[slicing] + +class AttnChunk(NamedTuple): + exp_values: Tensor + exp_weights_sum: Tensor + max_score: Tensor + +class SummarizeChunk(Protocol): + @staticmethod + def __call__( + query: Tensor, + key: Tensor, + value: Tensor, + ) -> AttnChunk: ... + +class ComputeQueryChunkAttn(Protocol): + @staticmethod + def __call__( + query: Tensor, + key: Tensor, + value: Tensor, + ) -> Tensor: ... 
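# _summarize_chunk below performs the per-chunk reduction of the streaming
# softmax: each key/value chunk is summarized as (exp_values, exp_weights_sum,
# max_score), and _query_chunk_attention later rescales the chunks by
# exp(chunk_max - global_max) before summing. The max-shift identity this
# relies on can be sanity-checked in isolation (standalone snippet, not part
# of the patch):
#
#     import torch
#     s = torch.tensor([1.0, 2.0, 3.0])
#     shifted = (s - s.max()).exp()
#     assert torch.allclose(torch.softmax(s, dim=0), shifted / shifted.sum())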
+ +def _summarize_chunk( + query: Tensor, + key: Tensor, + value: Tensor, + scale: float, +) -> AttnChunk: + attn_weights = torch.baddbmm( + torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), + query, + key.transpose(1,2), + alpha=scale, + beta=0, + ) + max_score, _ = torch.max(attn_weights, -1, keepdim=True) + max_score = max_score.detach() + exp_weights = torch.exp(attn_weights - max_score) + exp_values = torch.bmm(exp_weights, value) + max_score = max_score.squeeze(-1) + return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) + +def _query_chunk_attention( + query: Tensor, + key: Tensor, + value: Tensor, + summarize_chunk: SummarizeChunk, + kv_chunk_size: int, +) -> Tensor: + batch_x_heads, k_tokens, k_channels_per_head = key.shape + _, _, v_channels_per_head = value.shape + + def chunk_scanner(chunk_idx: int) -> AttnChunk: + key_chunk = dynamic_slice( + key, + (0, chunk_idx, 0), + (batch_x_heads, kv_chunk_size, k_channels_per_head) + ) + value_chunk = dynamic_slice( + value, + (0, chunk_idx, 0), + (batch_x_heads, kv_chunk_size, v_channels_per_head) + ) + return summarize_chunk(query, key_chunk, value_chunk) + + chunks: List[AttnChunk] = [ + chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) + ] + acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) + chunk_values, chunk_weights, chunk_max = acc_chunk + + global_max, _ = torch.max(chunk_max, 0, keepdim=True) + max_diffs = torch.exp(chunk_max - global_max) + chunk_values *= torch.unsqueeze(max_diffs, -1) + chunk_weights *= max_diffs + + all_values = chunk_values.sum(dim=0) + all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) + return all_values / all_weights + +# TODO: refactor CrossAttention#get_attention_scores to share code with this +def _get_attention_scores_no_kv_chunking( + query: Tensor, + key: Tensor, + value: Tensor, + scale: float, +) -> Tensor: + attn_scores = torch.baddbmm( + torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), + query, + key.transpose(1,2), + alpha=scale, + beta=0, + ) + attn_probs = attn_scores.softmax(dim=-1) + del attn_scores + hidden_states_slice = torch.bmm(attn_probs, value) + return hidden_states_slice + +class ScannedChunk(NamedTuple): + chunk_idx: int + attn_chunk: AttnChunk + +def efficient_dot_product_attention( + query: Tensor, + key: Tensor, + value: Tensor, + query_chunk_size=1024, + kv_chunk_size: Optional[int] = None, + kv_chunk_size_min: Optional[int] = None, + use_checkpoint=True, +): + """Computes efficient dot-product attention given query, key, and value. + This is efficient version of attention presented in + https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements. + Args: + query: queries for calculating attention with shape of + `[batch * num_heads, tokens, channels_per_head]`. + key: keys for calculating attention with shape of + `[batch * num_heads, tokens, channels_per_head]`. + value: values to be used in attention with shape of + `[batch * num_heads, tokens, channels_per_head]`. + query_chunk_size: int: query chunks size + kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens) + kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done). 
+ use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference) + Returns: + Output of shape `[batch * num_heads, query_tokens, channels_per_head]`. + """ + batch_x_heads, q_tokens, q_channels_per_head = query.shape + _, k_tokens, _ = key.shape + scale = q_channels_per_head ** -0.5 + + kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens) + if kv_chunk_size_min is not None: + kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) + + def get_query_chunk(chunk_idx: int) -> Tensor: + return dynamic_slice( + query, + (0, chunk_idx, 0), + (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) + ) + + summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) + summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk + compute_query_chunk_attn: ComputeQueryChunkAttn = partial( + _get_attention_scores_no_kv_chunking, + scale=scale + ) if k_tokens <= kv_chunk_size else ( + # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw) + partial( + _query_chunk_attention, + kv_chunk_size=kv_chunk_size, + summarize_chunk=summarize_chunk, + ) + ) + + if q_tokens <= query_chunk_size: + # fast-path for when there's just 1 query chunk + return compute_query_chunk_attn( + query=query, + key=key, + value=value, + ) + + # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, + # and pass slices to be mutated, instead of torch.cat()ing the returned slices + res = torch.cat([ + compute_query_chunk_attn( + query=get_query_chunk(i * query_chunk_size), + key=key, + value=value, + ) for i in range(math.ceil(q_tokens / query_chunk_size)) + ], dim=1) + return res diff --git a/requirements.txt b/requirements.txt index 5bed694e..0dbea322 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,4 +30,4 @@ inflection GitPython torchsde safetensors -psutil; sys_platform == 'darwin' +psutil -- cgit v1.2.3 From b119815333026164f2bd7d1ca71f3e4f7a9afd0d Mon Sep 17 00:00:00 2001 From: brkirch Date: Thu, 5 Jan 2023 04:37:17 -0500 Subject: Use narrow instead of dynamic_slice --- modules/sub_quadratic_attention.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index b11dc1c7..95924d24 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -5,6 +5,7 @@ # credit: # Amin Rezaei (original author) # Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks) +# brkirch (modified to use torch.narrow instead of dynamic_slice implementation) # implementation of: # Self-attention Does Not Need O(n2) Memory": # https://arxiv.org/abs/2112.05682v2 @@ -16,13 +17,13 @@ from torch.utils.checkpoint import checkpoint import math from typing import Optional, NamedTuple, Protocol, List -def dynamic_slice( - x: Tensor, - starts: List[int], - sizes: List[int], +def narrow_trunc( + input: Tensor, + dim: int, + start: int, + length: int ) -> Tensor: - slicing = [slice(start, start + size) for start, size in zip(starts, sizes)] - return x[slicing] + return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start) class AttnChunk(NamedTuple): exp_values: Tensor @@ -76,15 +77,17 @@ def _query_chunk_attention( _, _, v_channels_per_head = value.shape def chunk_scanner(chunk_idx: int) -> 
AttnChunk: - key_chunk = dynamic_slice( + key_chunk = narrow_trunc( key, - (0, chunk_idx, 0), - (batch_x_heads, kv_chunk_size, k_channels_per_head) + 1, + chunk_idx, + kv_chunk_size ) - value_chunk = dynamic_slice( + value_chunk = narrow_trunc( value, - (0, chunk_idx, 0), - (batch_x_heads, kv_chunk_size, v_channels_per_head) + 1, + chunk_idx, + kv_chunk_size ) return summarize_chunk(query, key_chunk, value_chunk) @@ -161,10 +164,11 @@ def efficient_dot_product_attention( kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) def get_query_chunk(chunk_idx: int) -> Tensor: - return dynamic_slice( + return narrow_trunc( query, - (0, chunk_idx, 0), - (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) + 1, + chunk_idx, + min(query_chunk_size, q_tokens) ) summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) -- cgit v1.2.3 From 683287d87f6401083a8d63eedc00ca7410214ca1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 6 Jan 2023 08:52:06 +0300 Subject: rework saving training params to file #6372 --- modules/hypernetworks/hypernetwork.py | 28 +++++++------------------- modules/shared.py | 2 +- modules/textual_inversion/logging.py | 24 ++++++++++++++++++++++ modules/textual_inversion/textual_inversion.py | 23 +++------------------ 4 files changed, 35 insertions(+), 42 deletions(-) create mode 100644 modules/textual_inversion/logging.py diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3237c37a..b0cfbe71 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -13,7 +13,7 @@ import tqdm from einops import rearrange, repeat from ldm.util import default from modules import devices, processing, sd_models, shared, sd_samplers -from modules.textual_inversion import textual_inversion +from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_ @@ -401,25 +401,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, hypernet.save(fn) shared.reload_hypernetworks() -# Note: textual_inversion.py has a nearly identical function of the same name. -def save_settings_to_file(model_name, model_hash, initial_step, num_of_dataset_images, hypernetwork_name, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - # Starting index of preview-related arguments. - border_index = 21 - # Get a list of the argument names. - arg_names = inspect.getfullargspec(save_settings_to_file).args - # Create a list of the argument names to include in the settings string. - names = arg_names[:border_index] # Include all arguments up until the preview-related ones. - if preview_from_txt2img: - names.extend(arg_names[border_index:]) # Include preview-related arguments if applicable. - # Build the settings string. - settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" - for name in names: - if name != 'log_directory': # It's useless and redundant to save log_directory. 
- value = locals()[name] - settings_str += f"{name}: {value}\n" - # Create or append to the file. - with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: - fout.write(settings_str + "\n\n") + def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. @@ -477,7 +459,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) if shared.opts.save_training_settings_to_txt: - save_settings_to_file(checkpoint.model_name, '[{}]'.format(checkpoint.hash), initial_step, len(ds), hypernetwork_name, hypernetwork.layer_structure, hypernetwork.activation_func, hypernetwork.weight_init, hypernetwork.add_layer_norm, hypernetwork.use_dropout, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + saved_params = dict( + model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), + **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]} + ) + logging.save_settings_to_file(log_directory, {**saved_params, **locals()}) latent_sampling_method = ds.latent_sampling_method diff --git a/modules/shared.py b/modules/shared.py index f0e10b35..57e489d0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -362,7 +362,7 @@ options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. 
Training of embedding or HN can be resumed with the matching optim file."), - "save_training_settings_to_txt": OptionInfo(False, "Save textual inversion and hypernet settings to a text file whenever training starts."), + "save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py new file mode 100644 index 00000000..8b1981d5 --- /dev/null +++ b/modules/textual_inversion/logging.py @@ -0,0 +1,24 @@ +import datetime +import json +import os + +saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"} +saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"} +saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"} +saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet +saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"} + + +def save_settings_to_file(log_directory, all_params): + now = datetime.datetime.now() + params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")} + + keys = saved_params_all + if all_params.get('preview_from_txt2img'): + keys = keys | saved_params_previews + + params.update({k: v for k, v in all_params.items() if k in keys}) + + filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json' + with open(os.path.join(log_directory, filename), "w") as file: + json.dump(params, file, indent=4) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index e9cf432f..f9f5e8cd 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -18,6 +18,8 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay) +from modules.textual_inversion.logging import save_settings_to_file + class Embedding: def __init__(self, vec, name, step=None): @@ -231,25 +233,6 @@ def write_loss(log_directory, filename, step, epoch_len, values): **values, }) -# Note: hypernetwork.py has a nearly identical function of the same name. -def save_settings_to_file(model_name, model_hash, initial_step, num_of_dataset_images, embedding_name, vectors_per_token, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - # Starting index of preview-related arguments. 
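The replacement approach, shown in full in the new modules/textual_inversion/logging.py above, inverts that idiom: the caller dumps a locals() snapshot and a whitelist of key names decides what gets written. A reduced, self-contained sketch with a toy key set and a hypothetical trainer:

import json

SAVED_KEYS = {"model_name", "learn_rate", "batch_size", "num_of_dataset_images"}

def train(model_name, learn_rate, batch_size):
    scratch = [0] * 1000                       # never saved: not in the whitelist
    extra = dict(num_of_dataset_images=120)
    params = {k: v for k, v in {**extra, **locals()}.items() if k in SAVED_KEYS}
    return json.dumps(params, indent=4)

print(train("v1-5-pruned", 0.005, 2))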
- border_index = 18 - # Get a list of the argument names. - arg_names = inspect.getfullargspec(save_settings_to_file).args - # Create a list of the argument names to include in the settings string. - names = arg_names[:border_index] # Include all arguments up until the preview-related ones. - if preview_from_txt2img: - names.extend(arg_names[border_index:]) # Include preview-related arguments if applicable. - # Build the settings string. - settings_str = "datetime : " + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "\n" - for name in names: - if name != 'log_directory': # It's useless and redundant to save log_directory. - value = locals()[name] - settings_str += f"{name}: {value}\n" - # Create or append to the file. - with open(os.path.join(log_directory, 'settings.txt'), "a+") as fout: - fout.write(settings_str + "\n\n") def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): assert model_name, f"{name} not selected" @@ -330,7 +313,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) if shared.opts.save_training_settings_to_txt: - save_settings_to_file(checkpoint.model_name, '[{}]'.format(checkpoint.hash), initial_step, len(ds), embedding_name, len(embedding.vec), learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height) + save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) latent_sampling_method = ds.latent_sampling_method -- cgit v1.2.3 From b95a4c0ce5ab9c414e0494193bfff665f45e9e65 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 6 Jan 2023 01:01:51 -0500 Subject: Change sub-quad chunk threshold to use percentage --- modules/sd_hijack_optimizations.py | 18 +++++++++--------- modules/shared.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index f5c153e8..b416e9ac 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -233,7 +233,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) - x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, 
chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) @@ -243,20 +243,20 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): return x -def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold_bytes=None, use_checkpoint=True): +def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True): bytes_per_token = torch.finfo(q.dtype).bits//8 batch_x_heads, q_tokens, _ = q.shape _, k_tokens, _ = k.shape qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens - available_vram = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) - - if chunk_threshold_bytes is None: - chunk_threshold_bytes = available_vram - elif chunk_threshold_bytes == 0: + if chunk_threshold is None: + chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7) + elif chunk_threshold == 0: chunk_threshold_bytes = None + else: + chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram()) - if kv_chunk_size_min is None: + if kv_chunk_size_min is None and chunk_threshold_bytes is not None: kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2])) elif kv_chunk_size_min == 0: kv_chunk_size_min = None @@ -382,7 +382,7 @@ def sub_quad_attnblock_forward(self, x): q = q.contiguous() k = k.contiguous() v = v.contiguous() - out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold_bytes=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) return x + out diff --git a/modules/shared.py b/modules/shared.py index cb1dc312..d7a81db1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -59,7 +59,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization") parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024) parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None) -parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the size threshold in bytes for the sub-quadratic cross-attention layer optimization to use chunking", default=None) +parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None) parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. 
By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") -- cgit v1.2.3 From 5deb2a19ccea57a50252e8fcb07b4d17c6599def Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 6 Jan 2023 01:33:15 -0500 Subject: Allow Doggettx's cross attention opt without CUDA --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index ef25dadb..bd101e5b 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -50,7 +50,7 @@ def apply_optimizations(): print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 optimization_method = 'V1' - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): + elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()): print("Applying cross attention optimization (InvokeAI).") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI optimization_method = 'InvokeAI' -- cgit v1.2.3 From c9bded39ee05bd0507ccd27d2b674d86d6c0c8e8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 6 Jan 2023 12:32:44 +0300 Subject: sort extensions by date and add an option to sort by other columns --- modules/ui_extensions.py | 44 ++++++++++++++++++++++++++++++++------------ style.css | 11 ++++++++++- 2 files changed, 42 insertions(+), 13 deletions(-) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index eec9586f..742e745e 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -162,15 +162,15 @@ def install_extension_from_url(dirname, url): shutil.rmtree(tmpdir, True) -def install_extension_from_index(url, hide_tags): +def install_extension_from_index(url, hide_tags, sort_column): ext_table, message = install_extension_from_url(None, url) - code, _ = refresh_available_extensions_from_data(hide_tags) + code, _ = refresh_available_extensions_from_data(hide_tags, sort_column) return code, ext_table, message -def refresh_available_extensions(url, hide_tags): +def refresh_available_extensions(url, hide_tags, sort_column): global available_extensions import urllib.request @@ -179,18 +179,28 @@ def refresh_available_extensions(url, hide_tags): available_extensions = json.loads(text) - code, tags = refresh_available_extensions_from_data(hide_tags) + code, tags = refresh_available_extensions_from_data(hide_tags, sort_column) return url, code, gr.CheckboxGroup.update(choices=tags), '' -def refresh_available_extensions_for_tags(hide_tags): - code, _ = refresh_available_extensions_from_data(hide_tags) +def refresh_available_extensions_for_tags(hide_tags, sort_column): + code, _ = refresh_available_extensions_from_data(hide_tags, sort_column) return code, '' -def refresh_available_extensions_from_data(hide_tags): +sort_ordering = [ + # (reverse, order_by_function) + (True, lambda x: x.get('added', 'z')), + (False, lambda x: x.get('added', 'z')), + (False, lambda x: x.get('name', 'z')), + (True, lambda x: x.get('name', 'z')), + (False, lambda x: 'z'), +] + + +def 
refresh_available_extensions_from_data(hide_tags, sort_column): extlist = available_extensions["extensions"] installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions} @@ -210,8 +220,11 @@ def refresh_available_extensions_from_data(hide_tags): """ - for ext in extlist: + sort_reverse, sort_function = sort_ordering[sort_column if 0 <= sort_column < len(sort_ordering) else 0] + + for ext in sorted(extlist, key=sort_function, reverse=sort_reverse): name = ext.get("name", "noname") + added = ext.get('added', 'unknown') url = ext.get("url", None) description = ext.get("description", "") extension_tags = ext.get("tags", []) @@ -233,7 +246,7 @@ def refresh_available_extensions_from_data(hide_tags): code += f""" {html.escape(name)}
{tags_text} - {html.escape(description)} + {html.escape(description)} Added: {html.escape(added)}
{install_code} @@ -291,25 +304,32 @@ def create_ui(): with gr.Row(): hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"]) + sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order", ], type="index") install_result = gr.HTML() available_extensions_table = gr.HTML() refresh_available_extensions_button.click( fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]), - inputs=[available_extensions_index, hide_tags], + inputs=[available_extensions_index, hide_tags, sort_column], outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result], ) install_extension_button.click( fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]), - inputs=[extension_to_install, hide_tags], + inputs=[extension_to_install, hide_tags, sort_column], outputs=[available_extensions_table, extensions_table, install_result], ) hide_tags.change( fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), - inputs=[hide_tags], + inputs=[hide_tags, sort_column], outputs=[available_extensions_table, install_result] ) + + sort_column.change( + fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), + inputs=[hide_tags, sort_column], outputs=[available_extensions_table, install_result] ) diff --git a/style.css b/style.css index ee74d79e..f1b23b53 100644 --- a/style.css +++ b/style.css @@ -555,7 +555,7 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h /* Extensions */ -#tab_extensions table{ border-collapse: collapse; } @@ -581,6 +581,15 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h font-size: 95%; } +#available_extensions .info{ + margin: 0; +} + +#available_extensions .date_added{ + opacity: 0.85; + font-size: 90%; +} + #image_buttons_txt2img button, #image_buttons_img2img button, #image_buttons_extras button{ min-width: auto; padding-left: 0.5em; -- cgit v1.2.3 From 65ed4421e609dda3112f236c13e4db14caa71364 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 6 Jan 2023 13:55:50 +0300 Subject: add callback for when the script is unloaded --- modules/script_callbacks.py | 18 +++++++++++++++++- webui.py | 2 ++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index de69fd9f..608c5300 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -71,6 +71,7 @@ callback_map = dict( callbacks_before_component=[], callbacks_after_component=[], callbacks_image_grid=[], + callbacks_script_unloaded=[], ) @@ -171,6 +172,14 @@ def image_grid_callback(params: ImageGridLoopParams): report_exception(c, 'image_grid') +def script_unloaded_callback(): + for c in reversed(callback_map['callbacks_script_unloaded']): + try: + c.callback() + except Exception: + report_exception(c, 'script_unloaded') + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if len(stack) > 0 else 'unknown file' @@ -202,7 +211,7 @@ def on_app_started(callback): def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is -
passed as an argument""" + passed as an argument; this function is also called when the script is reloaded. """ add_callback(callback_map['callbacks_model_loaded'], callback) @@ -279,3 +288,10 @@ def on_image_grid(callback): - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified. """ add_callback(callback_map['callbacks_image_grid'], callback) + + +def on_script_unloaded(callback): + """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that + the script did should be reverted here""" + + add_callback(callback_map['callbacks_script_unloaded'], callback) diff --git a/webui.py b/webui.py index ff6eb6eb..733a06b5 100644 --- a/webui.py +++ b/webui.py @@ -187,12 +187,14 @@ def webui(): sd_samplers.set_samplers() + modules.script_callbacks.script_unloaded_callback() extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) modelloader.forbid_loaded_nonbuiltin_upscalers() modules.scripts.reload_scripts() + modules.script_callbacks.model_loaded_callback(shared.sd_model) modelloader.load_upscalers() for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: -- cgit v1.2.3 From 848605fb654a55ee6947335d7df6e13366606fad Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 6 Jan 2023 06:58:49 -0500 Subject: Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c1944d33..fea6cb35 100644 --- a/README.md +++ b/README.md @@ -141,7 +141,7 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - Ideas for optimizations - https://github.com/basujindal/stable-diffusion - Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing. - Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion) -- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san), Amin Rezaei (https://github.com/AminRezaei0x443) +- Sub-quadratic Cross Attention layer optimization - Alex Birch (https://github.com/Birch-san/diffusers/pull/1), Amin Rezaei (https://github.com/AminRezaei0x443/memory-efficient-attention) - Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas). 
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot -- cgit v1.2.3 From 5e6566324bba20554bcc04f3dda798e560397f38 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 6 Jan 2023 07:06:26 -0500 Subject: Always end version number with a digit --- webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.py b/webui.py index 733a06b5..8737e593 100644 --- a/webui.py +++ b/webui.py @@ -16,7 +16,7 @@ from modules.paths import script_path import torch # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors if ".dev" in torch.__version__ or "+git" in torch.__version__: - torch.__version__ = re.search(r'[\d.]+', torch.__version__).group(0) + torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir import modules.codeformer_model as codeformer -- cgit v1.2.3 From 3246a2d6b898da6a98fe9df4dc67944635a41bd3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 6 Jan 2023 16:03:43 +0300 Subject: remove restriction for saving dropdowns to ui-config.json --- modules/scripts.py | 1 - modules/ui.py | 10 ++-------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/modules/scripts.py b/modules/scripts.py index 0c44f191..35164093 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -290,7 +290,6 @@ class ScriptRunner: script.group = group dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index") - dropdown.save_to_config = True inputs[0] = dropdown for script in self.selectable_scripts: diff --git a/modules/ui.py b/modules/ui.py index 030f0685..b79d24ee 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -435,11 +435,9 @@ def create_toprow(is_img2img): with gr.Row(): with gr.Column(scale=1, elem_id="style_pos_col"): prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - prompt_style.save_to_config = True with gr.Column(scale=1, elem_id="style_neg_col"): prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - prompt_style2.save_to_config = True return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button @@ -638,7 +636,6 @@ def create_sampler_and_steps_selection(choices, tabname): if opts.samplers_in_dropdown: with FormRow(elem_id=f"sampler_selection_{tabname}"): sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") - sampler_index.save_to_config = True steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) else: with FormGroup(elem_id=f"sampler_selection_{tabname}"): @@ -1794,7 +1791,7 @@ def create_ui(): if init_field is not None: init_field(saved_value) - if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number] and x.visible: + if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible: apply_field(x, 'visible') if type(x) == 
gr.Slider: @@ -1815,11 +1812,8 @@ def create_ui(): if type(x) == gr.Number: apply_field(x, 'value') - # Since there are many dropdowns that shouldn't be saved, - # we only mark dropdowns that should be saved. - if type(x) == gr.Dropdown and getattr(x, 'save_to_config', False): + if type(x) == gr.Dropdown: apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) - apply_field(x, 'visible') visit(txt2img_interface, loadsave, "txt2img") visit(img2img_interface, loadsave, "img2img") -- cgit v1.2.3 From 50194de93ffc9db763d9b08fcc9c3bde1aa86151 Mon Sep 17 00:00:00 2001 From: Kuma <36082288+KumiIT@users.noreply.github.com> Date: Fri, 6 Jan 2023 16:12:45 +0100 Subject: typo UI fixes #6391 --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index 57e489d0..865c3c07 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -430,7 +430,7 @@ options_templates.update(options_section(('ui', "User interface"), { "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"), "dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), - 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/ing2img UI item order"), + 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) -- cgit v1.2.3 From 3992ecbe6e46a465062508c677964534e7397f72 Mon Sep 17 00:00:00 2001 From: Mitchell Boot <47387831+Mitchell1711@users.noreply.github.com> Date: Fri, 6 Jan 2023 18:02:46 +0100 Subject: Added UI elements Added a new row to hires fix that shows the new resolution after scaling --- modules/ui.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/modules/ui.py b/modules/ui.py index b79d24ee..20f7d2a2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -255,6 +255,12 @@ def add_style(name: str, prompt: str, negative_prompt: str): return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] +def calc_resolution_hires(x, y, scale): + #final res can only be a multiple of 8 + scaled_x = int(x * scale // 8) * 8 + scaled_y = int(y * scale // 8) * 8 + + return "
<p>Upscaled Resolution: "+str(scaled_x)+"x"+str(scaled_y)+"</p>"
 def apply_styles(prompt, prompt_neg, style1_name, style2_name): prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) @@ -718,6 +724,12 @@ def create_ui(): hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") + + with FormRow(elem_id="txt2img_hires_fix_row3"): + hr_final_resolution = gr.HTML(value=calc_resolution_hires(width.value, height.value, hr_scale.value), elem_id="txtimg_hr_finalres") + hr_scale.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + width.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + height.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) elif category == "batch": if not opts.dimensions_and_batch_together: -- cgit v1.2.3 From 991368c8d54404d8e13d4c6e76a0f32644e65ad4 Mon Sep 17 00:00:00 2001 From: Mitchell Boot <47387831+Mitchell1711@users.noreply.github.com> Date: Fri, 6 Jan 2023 18:24:29 +0100 Subject: remove camelcase --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index 20f7d2a2..6fc8b7d7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -260,7 +260,7 @@ def calc_resolution_hires(x, y, scale): scaled_x = int(x * scale // 8) * 8 scaled_y = int(y * scale // 8) * 8 - return "<p>Upscaled Resolution: "+str(scaled_x)+"x"+str(scaled_y)+"</p>"
+ return "<p>Upscaled resolution: "+str(scaled_x)+"x"+str(scaled_y)+"</p>
" def apply_styles(prompt, prompt_neg, style1_name, style2_name): prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) -- cgit v1.2.3 From c18add68ef7d2de3617cbbaff864b0c74cfdf6c0 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 6 Jan 2023 16:42:47 -0500 Subject: Added license --- html/licenses.html | 29 ++++++++++++++++++++++++++++- modules/sd_hijack_optimizations.py | 1 + modules/sub_quadratic_attention.py | 2 +- 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/html/licenses.html b/html/licenses.html index 9eeaa072..570630eb 100644 --- a/html/licenses.html +++ b/html/licenses.html @@ -184,7 +184,7 @@ SOFTWARE.

SwinIR

-Code added by contirubtors, most likely copied from this repository. +Code added by contributors, most likely copied from this repository.
                                  Apache License
@@ -390,3 +390,30 @@ SOFTWARE.
    limitations under the License.
 
+
+Memory Efficient Attention
+The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that. +
+MIT License
+
+Copyright (c) 2023 Alex Birch
+Copyright (c) 2023 Amin Rezaei
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
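Since the license above covers the sub-quadratic attention code, a compact illustration of what that code does may be useful: keys and values are visited in slices while a running maximum and normalizer keep the softmax over all keys exact, so the full query-key score matrix is never materialized. This is a simplified sketch of the idea, not the repository's implementation:

import torch

def chunked_attention(q, k, v, kv_chunk=32):
    scale = q.shape[-1] ** -0.5
    acc = torch.zeros(q.shape[0], q.shape[1], v.shape[-1])
    running_max = torch.full(q.shape[:2], float("-inf"))
    normalizer = torch.zeros(q.shape[:2])
    for i in range(0, k.shape[1], kv_chunk):
        ks, vs = k[:, i:i + kv_chunk], v[:, i:i + kv_chunk]
        scores = (q @ ks.transpose(1, 2)) * scale            # (B, Tq, chunk)
        new_max = torch.maximum(running_max, scores.amax(dim=-1))
        weights = torch.exp(scores - new_max.unsqueeze(-1))
        correction = torch.exp(running_max - new_max)        # rescale earlier partial sums
        acc = acc * correction.unsqueeze(-1) + weights @ vs
        normalizer = normalizer * correction + weights.sum(dim=-1)
        running_max = new_max
    return acc / normalizer.unsqueeze(-1)

q, k, v = torch.randn(1, 8, 16), torch.randn(1, 100, 16), torch.randn(1, 100, 16)
exact = torch.softmax((q @ k.transpose(1, 2)) * 16 ** -0.5, dim=-1) @ v
print(torch.allclose(chunked_attention(q, k, v), exact, atol=1e-5))  # True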
+ diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index b416e9ac..cdc63ed7 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -216,6 +216,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 +# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface def sub_quad_attention_forward(self, x, context=None, mask=None): assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 95924d24..fea7aaac 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -1,7 +1,7 @@ # original source: # https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py # license: -# unspecified +# MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license) # credit: # Amin Rezaei (original author) # Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks) -- cgit v1.2.3 From 82c1f10b144f733460feead0bdc37a861489dc57 Mon Sep 17 00:00:00 2001 From: Dean Hopkins Date: Fri, 6 Jan 2023 22:00:12 +0000 Subject: increase upscale api validation limit --- modules/api/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/api/models.py b/modules/api/models.py index f77951fc..22b88c59 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -125,7 +125,7 @@ class ExtrasBaseRequest(BaseModel): gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.") codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.") codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.") - upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.") + upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=8, description="By how much to upscale the image, only used when resize_mode=0.") upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.") upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. 
Only used when resize_mode=1.") upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?") -- cgit v1.2.3 From 79e39fae6110c20a3ee6255e2841c877f65e8cbd Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 01:45:28 +0300 Subject: CLIP hijack rework --- modules/sd_hijack.py | 6 +- modules/sd_hijack_clip.py | 348 ++++++++++++------------- modules/sd_hijack_clip_old.py | 81 ++++++ modules/textual_inversion/textual_inversion.py | 1 - modules/ui.py | 2 +- 5 files changed, 256 insertions(+), 182 deletions(-) create mode 100644 modules/sd_hijack_clip_old.py diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index fa2cd4bb..71cc145a 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -150,10 +150,10 @@ class StableDiffusionModelHijack: def clear_comments(self): self.comments = [] - def tokenize(self, text): - _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) + def get_prompt_lengths(self, text): + _, token_count = self.clip.process_texts([text]) - return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count) + return token_count, self.clip.get_target_prompt_token_count(token_count) class EmbeddingsWithFixes(torch.nn.Module): diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index ca92b142..ac3020d7 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -1,12 +1,28 @@ import math +from collections import namedtuple import torch from modules import prompt_parser, devices from modules.shared import opts -def get_target_prompt_token_count(token_count): - return math.ceil(max(token_count, 1) / 75) * 75 + +class PromptChunk: + """ + This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt. + If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary. + Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token, + so just 75 tokens from prompt. + """ + + def __init__(self): + self.tokens = [] + self.multipliers = [] + self.fixes = [] + + +PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding']) +"""This is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt chunk""" class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): @@ -14,17 +30,49 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): super().__init__() self.wrapped = wrapped self.hijack = hijack + self.chunk_length = 75 + + def empty_chunk(self): + """creates an empty PromptChunk and returns it""" + + chunk = PromptChunk() + chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1) + chunk.multipliers = [1.0] * (self.chunk_length + 2) + return chunk + + def get_target_prompt_token_count(self, token_count): + """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented""" + + return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length def tokenize(self, texts): + """Converts a batch of texts into a batch of token ids""" + raise NotImplementedError def encode_with_transformers(self, tokens): + """ + converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens; + All python lists with tokens are assumed to have same length, usually 77. 
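The 77-token batches described here, and the (B, T, C) output shape detailed just below, can be made concrete with a small hedged sketch (SD1-like sizes assumed): chunk outputs are concatenated along the token axis, so the final T is always a multiple of 77:

import torch

B, C = 2, 768                  # batch of 2 prompts, SD1-sized embeddings
z1 = torch.randn(B, 77, C)     # first chunk: start token + 75 tokens + end token
z2 = torch.randn(B, 77, C)     # second chunk, present only for long prompts
z = torch.hstack([z1, z2])     # for 3-D tensors, hstack concatenates along dim 1
print(z.shape)                 # torch.Size([2, 154, 768])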
+ if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on + model - can be 768 and 1024 + """ + raise NotImplementedError def encode_embedding_init_text(self, init_text, nvpt): + """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through + transformers. nvpt is used as a maximum length in tokens. If text produces less teokens than nvpt, only this many is returned.""" + raise NotImplementedError - def tokenize_line(self, line, used_custom_terms, hijack_comments): + def tokenize_line(self, line): + """ + this transforms a single prompt into a list of PromptChunk objects - as many as needed to + represent the prompt. + Returns the list and the total number of tokens in the prompt. + """ + if opts.enable_emphasis: parsed = prompt_parser.parse_prompt_attention(line) else: @@ -32,205 +80,152 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): tokenized = self.tokenize([text for text, _ in parsed]) - fixes = [] - remade_tokens = [] - multipliers = [] + chunks = [] + chunk = PromptChunk() + token_count = 0 last_comma = -1 - for tokens, (text, weight) in zip(tokenized, parsed): - i = 0 - while i < len(tokens): - token = tokens[i] + def next_chunk(): + """puts current chunk into the list of results and produces the next one - empty""" + nonlocal token_count + nonlocal last_comma + nonlocal chunk + + token_count += len(chunk.tokens) + to_add = self.chunk_length - len(chunk.tokens) + if to_add > 0: + chunk.tokens += [self.id_end] * to_add + chunk.multipliers += [1.0] * to_add - embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end] + chunk.multipliers = [1.0] + chunk.multipliers + [1.0] + + last_comma = -1 + chunks.append(chunk) + chunk = PromptChunk() + + for tokens, (text, weight) in zip(tokenized, parsed): + position = 0 + while position < len(tokens): + token = tokens[position] if token == self.comma_token: - last_comma = len(remade_tokens) - elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack: - last_comma += 1 - reloc_tokens = remade_tokens[last_comma:] - reloc_mults = multipliers[last_comma:] + last_comma = len(chunk.tokens) + + # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack + # is a setting that specifies that is there is a comma nearby, the text after comma should be moved out of this chunk and into the next. 
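A toy rendering of that backtracking rule, with made-up token ids and a shortened chunk length, shows the effect:

COMMA, CHUNK_LEN, BACKTRACK = 267, 8, 3   # toy values; the real chunk length is 75

def chunk_tokens(tokens):
    chunks, current, last_comma = [], [], -1
    for token in tokens:
        if token == COMMA:
            last_comma = len(current)     # remember where the comma will sit
        current.append(token)
        if len(current) == CHUNK_LEN:
            carry = []
            if last_comma != -1 and CHUNK_LEN - (last_comma + 1) <= BACKTRACK:
                current, carry = current[:last_comma + 1], current[last_comma + 1:]
            chunks.append(current)
            current, last_comma = carry, -1
    if current:
        chunks.append(current)
    return chunks

# tokens 7 and 8 fall shortly after the comma, so they move to the next chunk:
print(chunk_tokens([1, 2, 3, 4, 5, COMMA, 7, 8, 9, 10]))
# [[1, 2, 3, 4, 5, 267], [7, 8, 9, 10]]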
+ elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack: + break_location = last_comma + 1 + + reloc_tokens = chunk.tokens[break_location:] + reloc_mults = chunk.multipliers[break_location:] - remade_tokens = remade_tokens[:last_comma] - length = len(remade_tokens) + chunk.tokens = chunk.tokens[:break_location] + chunk.multipliers = chunk.multipliers[:break_location] - rem = int(math.ceil(length / 75)) * 75 - length - remade_tokens += [self.id_end] * rem + reloc_tokens - multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults + next_chunk() + chunk.tokens = reloc_tokens + chunk.multipliers = reloc_mults + if len(chunk.tokens) == self.chunk_length: + next_chunk() + + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position) if embedding is None: - remade_tokens.append(token) - multipliers.append(weight) - i += 1 - else: - emb_len = int(embedding.vec.shape[0]) - iteration = len(remade_tokens) // 75 - if (len(remade_tokens) + emb_len) // 75 != iteration: - rem = (75 * (iteration + 1) - len(remade_tokens)) - remade_tokens += [self.id_end] * rem - multipliers += [1.0] * rem - iteration += 1 - fixes.append((iteration, (len(remade_tokens) % 75, embedding))) - remade_tokens += [0] * emb_len - multipliers += [weight] * emb_len - used_custom_terms.append((embedding.name, embedding.checksum())) - i += embedding_length_in_tokens - - token_count = len(remade_tokens) - prompt_target_length = get_target_prompt_token_count(token_count) - tokens_to_add = prompt_target_length - len(remade_tokens) - - remade_tokens = remade_tokens + [self.id_end] * tokens_to_add - multipliers = multipliers + [1.0] * tokens_to_add - - return remade_tokens, fixes, multipliers, token_count - - def process_text(self, texts): - used_custom_terms = [] - remade_batch_tokens = [] - hijack_comments = [] - hijack_fixes = [] + chunk.tokens.append(token) + chunk.multipliers.append(weight) + position += 1 + continue + + emb_len = int(embedding.vec.shape[0]) + if len(chunk.tokens) + emb_len > self.chunk_length: + next_chunk() + + chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding)) + + chunk.tokens += [0] * emb_len + chunk.multipliers += [weight] * emb_len + position += embedding_length_in_tokens + + if len(chunk.tokens) > 0: + next_chunk() + + return chunks, token_count + + def process_texts(self, texts): + """ + Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum + length, in tokens, of all texts. 
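The caching mentioned here fits in a few lines; a sketch with a hypothetical tokenize_line, illustrating why duplicate prompt lines (common with batch sizes above 1) are tokenized only once:

def process_texts(texts, tokenize_line):
    cache, batch_chunks, token_count = {}, [], 0
    for line in texts:
        if line not in cache:
            cache[line] = tokenize_line(line)   # the expensive call runs once per unique line
        chunks, count = cache[line]
        token_count = max(token_count, count)
        batch_chunks.append(chunks)
    return batch_chunks, token_count

fake_tokenizer = lambda s: (s.split(), len(s.split()))
print(process_texts(["a cat", "a cat", "a very big cat"], fake_tokenizer)[1])  # 4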
+ """ + token_count = 0 cache = {} - batch_multipliers = [] + batch_chunks = [] for line in texts: if line in cache: - remade_tokens, fixes, multipliers = cache[line] + chunks = cache[line] else: - remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) + chunks, current_token_count = self.tokenize_line(line) token_count = max(current_token_count, token_count) - cache[line] = (remade_tokens, fixes, multipliers) + cache[line] = chunks - remade_batch_tokens.append(remade_tokens) - hijack_fixes.append(fixes) - batch_multipliers.append(multipliers) + batch_chunks.append(chunks) - return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + return batch_chunks, token_count - def process_text_old(self, texts): - id_start = self.id_start - id_end = self.id_end - maxlen = self.wrapped.max_length # you get to stay at 77 - used_custom_terms = [] - remade_batch_tokens = [] - hijack_comments = [] - hijack_fixes = [] - token_count = 0 + def forward(self, texts): + """ + Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts. + Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will + be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024. + An example shape returned by this function can be: (2, 77, 768). + Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet + is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream" + """ - cache = {} - batch_tokens = self.tokenize(texts) - batch_multipliers = [] - for tokens in batch_tokens: - tuple_tokens = tuple(tokens) + if opts.use_old_emphasis_implementation: + import modules.sd_hijack_clip_old + return modules.sd_hijack_clip_old.forward_old(self, texts) - if tuple_tokens in cache: - remade_tokens, fixes, multipliers = cache[tuple_tokens] - else: - fixes = [] - remade_tokens = [] - multipliers = [] - mult = 1.0 - - i = 0 - while i < len(tokens): - token = tokens[i] - - embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) - - mult_change = self.token_mults.get(token) if opts.enable_emphasis else None - if mult_change is not None: - mult *= mult_change - i += 1 - elif embedding is None: - remade_tokens.append(token) - multipliers.append(mult) - i += 1 - else: - emb_len = int(embedding.vec.shape[0]) - fixes.append((len(remade_tokens), embedding)) - remade_tokens += [0] * emb_len - multipliers += [mult] * emb_len - used_custom_terms.append((embedding.name, embedding.checksum())) - i += embedding_length_in_tokens - - if len(remade_tokens) > maxlen - 2: - vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} - ovf = remade_tokens[maxlen - 2:] - overflowing_words = [vocab.get(int(x), "") for x in ovf] - overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") - - token_count = len(remade_tokens) - remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) - remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] - cache[tuple_tokens] = (remade_tokens, fixes, multipliers) - - multipliers = 
multipliers + [1.0] * (maxlen - 2 - len(multipliers)) - multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] - - remade_batch_tokens.append(remade_tokens) - hijack_fixes.append(fixes) - batch_multipliers.append(multipliers) - return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - - def forward(self, text): - use_old = opts.use_old_emphasis_implementation - if use_old: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) - else: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - - self.hijack.comments += hijack_comments - - if len(used_custom_terms) > 0: - self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) - - if use_old: - self.hijack.fixes = hijack_fixes - return self.process_tokens(remade_batch_tokens, batch_multipliers) - - z = None - i = 0 - while max(map(len, remade_batch_tokens)) != 0: - rem_tokens = [x[75:] for x in remade_batch_tokens] - rem_multipliers = [x[75:] for x in batch_multipliers] - - self.hijack.fixes = [] - for unfiltered in hijack_fixes: - fixes = [] - for fix in unfiltered: - if fix[0] == i: - fixes.append(fix[1]) - self.hijack.fixes.append(fixes) - - tokens = [] - multipliers = [] - for j in range(len(remade_batch_tokens)): - if len(remade_batch_tokens[j]) > 0: - tokens.append(remade_batch_tokens[j][:75]) - multipliers.append(batch_multipliers[j][:75]) - else: - tokens.append([self.id_end] * 75) - multipliers.append([1.0] * 75) - - z1 = self.process_tokens(tokens, multipliers) - z = z1 if z is None else torch.cat((z, z1), axis=-2) - - remade_batch_tokens = rem_tokens - batch_multipliers = rem_multipliers - i += 1 + batch_chunks, token_count = self.process_texts(texts) - return z + used_embeddings = {} + chunk_count = max([len(x) for x in batch_chunks]) - def process_tokens(self, remade_batch_tokens, batch_multipliers): - if not opts.use_old_emphasis_implementation: - remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens] - batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] + zs = [] + for i in range(chunk_count): + batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks] + + tokens = [x.tokens for x in batch_chunk] + multipliers = [x.multipliers for x in batch_chunk] + self.hijack.fixes = [x.fixes for x in batch_chunk] + for fixes in self.hijack.fixes: + for position, embedding in fixes: + used_embeddings[embedding.name] = embedding + + z = self.process_tokens(tokens, multipliers) + zs.append(z) + + if len(used_embeddings) > 0: + embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()]) + self.hijack.comments.append(f"Used embeddings: {embeddings_list}") + + return torch.hstack(zs) + + def process_tokens(self, remade_batch_tokens, batch_multipliers): + """ + sends one single prompt chunk to be encoded by transformers neural network. + remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually + there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens. + Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier + corresponds to one token. 
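That weighting can be demonstrated in isolation; a minimal sketch of scaling token embeddings by their multipliers and then restoring the original mean, mirroring the code just below:

import torch

z = torch.randn(1, 77, 768)            # encoder output for one chunk
multipliers = torch.ones(1, 77)
multipliers[0, 5:10] = 1.4             # e.g. tokens inside an "(emphasis:1.4)" group
original_mean = z.mean()
z = z * multipliers.unsqueeze(-1)      # broadcast each token's weight over channels
z = z * (original_mean / z.mean())     # crude but effective artifact prevention
print(z.mean())                        # matches original_mean up to float error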
+ """ tokens = torch.asarray(remade_batch_tokens).to(devices.device) + # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones. if self.id_end != self.id_pad: for batch_pos in range(len(remade_batch_tokens)): index = remade_batch_tokens[batch_pos].index(self.id_end) @@ -239,8 +234,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): z = self.encode_with_transformers(tokens) # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise - batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers] - batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device) + batch_multipliers = torch.asarray(batch_multipliers).to(devices.device) original_mean = z.mean() z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) new_mean = z.mean() diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py new file mode 100644 index 00000000..6d9fbbe6 --- /dev/null +++ b/modules/sd_hijack_clip_old.py @@ -0,0 +1,81 @@ +from modules import sd_hijack_clip +from modules import shared + + +def process_text_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts): + id_start = self.id_start + id_end = self.id_end + maxlen = self.wrapped.max_length # you get to stay at 77 + used_custom_terms = [] + remade_batch_tokens = [] + hijack_comments = [] + hijack_fixes = [] + token_count = 0 + + cache = {} + batch_tokens = self.tokenize(texts) + batch_multipliers = [] + for tokens in batch_tokens: + tuple_tokens = tuple(tokens) + + if tuple_tokens in cache: + remade_tokens, fixes, multipliers = cache[tuple_tokens] + else: + fixes = [] + remade_tokens = [] + multipliers = [] + mult = 1.0 + + i = 0 + while i < len(tokens): + token = tokens[i] + + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + + mult_change = self.token_mults.get(token) if shared.opts.enable_emphasis else None + if mult_change is not None: + mult *= mult_change + i += 1 + elif embedding is None: + remade_tokens.append(token) + multipliers.append(mult) + i += 1 + else: + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [mult] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += embedding_length_in_tokens + + if len(remade_tokens) > maxlen - 2: + vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} + ovf = remade_tokens[maxlen - 2:] + overflowing_words = [vocab.get(int(x), "") for x in ovf] + overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) + hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + + token_count = len(remade_tokens) + remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) + remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] + cache[tuple_tokens] = (remade_tokens, fixes, multipliers) + + multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) + multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] + + remade_batch_tokens.append(remade_tokens) + hijack_fixes.append(fixes) + batch_multipliers.append(multipliers) + return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + + +def forward_old(self: 
sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, texts): + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = process_text_old(self, texts) + + self.hijack.comments += hijack_comments + + if len(used_custom_terms) > 0: + self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + + self.hijack.fixes = hijack_fixes + return self.process_tokens(remade_batch_tokens, batch_multipliers) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index f9f5e8cd..45882ed6 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -79,7 +79,6 @@ class EmbeddingDatabase: self.word_embeddings[embedding.name] = embedding - # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working ids = model.cond_stage_model.tokenize([embedding.name])[0] first_id = ids[0] diff --git a/modules/ui.py b/modules/ui.py index b79d24ee..5d2f5bad 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -368,7 +368,7 @@ def update_token_counter(text, steps): flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) prompts = [prompt_text for step, prompt_text in flat_prompts] - tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1]) + token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) style_class = ' class="red"' if (token_count > max_length) else "" return f"{token_count}/{max_length}" -- cgit v1.2.3 From f94cfc563bbedd923d5e95563a5e8d93c8516ac3 Mon Sep 17 00:00:00 2001 From: Mitchell Boot <47387831+Mitchell1711@users.noreply.github.com> Date: Sat, 7 Jan 2023 01:15:22 +0100 Subject: Changed HTML to textbox instead Using HTML caused an issue where the row would expand for a frame when changing the sliders because of the loading animation. This solution also doesn't use any additional HTML padding --- modules/ui.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 6fc8b7d7..6ea1b5d7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -260,7 +260,7 @@ def calc_resolution_hires(x, y, scale): scaled_x = int(x * scale // 8) * 8 scaled_y = int(y * scale // 8) * 8 - return "
<p>Upscaled resolution: "+str(scaled_x)+"x"+str(scaled_y)+"</p>
" + return str(scaled_x)+"x"+str(scaled_y) def apply_styles(prompt, prompt_neg, style1_name, style2_name): prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) @@ -726,7 +726,10 @@ def create_ui(): hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") with FormRow(elem_id="txt2img_hires_fix_row3"): - hr_final_resolution = gr.HTML(value=calc_resolution_hires(width.value, height.value, hr_scale.value), elem_id="txtimg_hr_finalres") + hr_final_resolution = gr.Textbox(value=calc_resolution_hires(width.value, height.value, hr_scale.value), + elem_id="txtimg_hr_finalres", + label="Upscaled resolution", + interactive=False) hr_scale.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) width.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) height.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) -- cgit v1.2.3 From 08066676a47b560235d4c085dd3cfcb470b80997 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 07:22:07 +0300 Subject: make it not break on empty inputs; thank you tarded, we are --- modules/sd_hijack_clip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index ac3020d7..16aef76a 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -147,7 +147,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): chunk.multipliers += [weight] * emb_len position += embedding_length_in_tokens - if len(chunk.tokens) > 0: + if len(chunk.tokens) > 0 or len(chunks) == 0: next_chunk() return chunks, token_count -- cgit v1.2.3 From 1740c33547b62f692834c95914a2b295d51684c7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 07:48:44 +0300 Subject: more comments --- modules/sd_hijack_clip.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 16aef76a..5520c9b2 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -3,7 +3,7 @@ from collections import namedtuple import torch -from modules import prompt_parser, devices +from modules import prompt_parser, devices, sd_hijack from modules.shared import opts @@ -22,14 +22,24 @@ class PromptChunk: PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding']) -"""This is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt chunk""" +"""An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt +chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally +are applied by sd_hijack.EmbeddingsWithFixes's forward function.""" class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): + """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to + have unlimited prompt length and assign weights to tokens in prompt. 
+ """ + def __init__(self, wrapped, hijack): super().__init__() + self.wrapped = wrapped - self.hijack = hijack + """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation, + depending on model.""" + + self.hijack: sd_hijack.StableDiffusionModelHijack = hijack self.chunk_length = 75 def empty_chunk(self): @@ -55,7 +65,8 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens; All python lists with tokens are assumed to have same length, usually 77. if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on - model - can be 768 and 1024 + model - can be 768 and 1024. + Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None). """ raise NotImplementedError @@ -113,7 +124,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): last_comma = len(chunk.tokens) # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack - # is a setting that specifies that is there is a comma nearby, the text after comma should be moved out of this chunk and into the next. + # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next. elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack: break_location = last_comma + 1 -- cgit v1.2.3 From de9738044571877450d1038e18f1ecce93d24af3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 08:53:53 +0300 Subject: this breaks on default config because width, height, hr_scale are None at that point. 
--- modules/ui.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index f946382d..a18b9007 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -725,14 +725,8 @@ def create_ui(): hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") - with FormRow(elem_id="txt2img_hires_fix_row3"): - hr_final_resolution = gr.Textbox(value=calc_resolution_hires(width.value, height.value, hr_scale.value), - elem_id="txtimg_hr_finalres", - label="Upscaled resolution", - interactive=False) - hr_scale.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) - width.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) - height.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + with FormRow(elem_id="txt2img_hires_fix_row3"): + hr_final_resolution = gr.Textbox(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) elif category == "batch": if not opts.dimensions_and_batch_together: @@ -744,6 +738,10 @@ def create_ui(): with FormGroup(elem_id="txt2img_script_container"): custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + hr_scale.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + width.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + height.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) -- cgit v1.2.3 From 1a5b86ad65fd738eadea1ad72f4abad3a4aabf17 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 09:56:37 +0300 Subject: rework hires fix preview for #6437: move it to where it takes less space, make it actually account for all relevant sliders and calculate dimensions correctly --- modules/processing.py | 1 - modules/ui.py | 40 +++++++++++++++++++++++++++------------- modules/ui_components.py | 8 ++++++++ style.css | 17 +++++++++++++++++ 4 files changed, 52 insertions(+), 14 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index a408d622..82157bc9 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -711,7 +711,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = 0 self.truncate_y = 0 - def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: if self.hr_resize_x == 0 and self.hr_resize_y == 0: diff --git a/modules/ui.py b/modules/ui.py index a18b9007..6c765262 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -20,7 +20,7 @@ from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru -from modules.ui_components import FormRow, FormGroup, ToolButton +from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import
script_path from modules.shared import opts, cmd_opts, restricted_opts @@ -255,12 +255,20 @@ def add_style(name: str, prompt: str, negative_prompt: str): return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] -def calc_resolution_hires(x, y, scale): - #final res can only be a multiple of 8 - scaled_x = int(x * scale // 8) * 8 - scaled_y = int(y * scale // 8) * 8 - - return str(scaled_x)+"x"+str(scaled_y) + +def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): + from modules import processing, devices + + if not enable: + return "" + + p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) + + with devices.autocast(): + p.init([""], [0], [0]) + + return f"resize to: {p.hr_upscale_to_x}x{p.hr_upscale_to_y}" + def apply_styles(prompt, prompt_neg, style1_name, style2_name): prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) @@ -712,6 +720,7 @@ def create_ui(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") + hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) elif category == "hires_fix": with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: @@ -724,9 +733,6 @@ def create_ui(): hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") - - with FormRow(elem_id="txt2img_hires_fix_row3"): - hr_final_resolution = gr.Textbox(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) elif category == "batch": if not opts.dimensions_and_batch_together: @@ -738,9 +744,16 @@ def create_ui(): with FormGroup(elem_id="txt2img_script_container"): custom_inputs = modules.scripts.scripts_txt2img.setup_ui() - hr_scale.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) - width.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) - height.change(fn=calc_resolution_hires, inputs=[width, height, hr_scale], outputs=hr_final_resolution, show_progress=False) + hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] + hr_resolution_preview_args = dict( + fn=calc_resolution_hires, + inputs=hr_resolution_preview_inputs, + outputs=[hr_final_resolution], + show_progress=False + ) + + for input in hr_resolution_preview_inputs: + input.change(**hr_resolution_preview_args) txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) @@ -803,6 +816,7 @@ def create_ui(): fn=lambda x: gr_show(x), inputs=[enable_hr], outputs=[hr_options], + show_progress = False, ) txt2img_paste_fields = [ diff --git a/modules/ui_components.py 
b/modules/ui_components.py index 91eb0e3d..cac001dc 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -23,3 +23,11 @@ class FormGroup(gr.Group, gr.components.FormComponent): def get_block_name(self): return "group" + + +class FormHTML(gr.HTML, gr.components.FormComponent): + """Same as gr.HTML but fits inside gradio forms""" + + def get_block_name(self): + return "html" + diff --git a/style.css b/style.css index f1b23b53..76721756 100644 --- a/style.css +++ b/style.css @@ -642,6 +642,23 @@ footer { opacity: 0.85; } +#txtimg_hr_finalres{ + min-height: 0 !important; + padding: .625rem .75rem; + margin-left: -0.75em + +} + +#txtimg_hr_finalres .resolution{ + font-weight: bold; +} + +#txt2img_checkboxes > div > div{ + flex: 0; + white-space: nowrap; + min-width: auto; +} + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running -- cgit v1.2.3 From a36e2744e2b18a2582247bc5b95bfa0339dfa629 Mon Sep 17 00:00:00 2001 From: Taithrah Date: Sat, 7 Jan 2023 04:09:02 -0500 Subject: Update hints.js Small touch up to hints --- javascript/hints.js | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/javascript/hints.js b/javascript/hints.js index dda66e09..73ab4a26 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -4,7 +4,7 @@ titles = { "Sampling steps": "How many times to improve the generated image iteratively; higher values take longer; very low values can produce bad results", "Sampling method": "Which algorithm to use to produce the image", "GFPGAN": "Restore low quality faces using GFPGAN neural network", - "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps to higher than 30-40 does not help", + "Euler a": "Euler Ancestral - very creative, each can get a completely different picture depending on step count, setting steps higher than 30-40 does not help", "DDIM": "Denoising Diffusion Implicit Models - best at inpainting", "DPM adaptive": "Ignores step count - uses a number of steps determined by the CFG and resolution", @@ -12,8 +12,8 @@ titles = { "Batch size": "How many image to create in a single batch", "CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results", "Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result", - "\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time", - "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed", + "\u{1f3b2}\ufe0f": "Set seed to -1 will set a new random number every time.", + "\u267b\ufe0f": "Reuse seed from last generation, most useful if it was randomized.", "\u{1f3a8}": "Add a random artist to the prompt.", "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.", "\u{1f4c2}": "Open images output directory", @@ -74,7 +74,7 @@ titles = { "Style 1": "Style to apply; styles have components for both positive and negative prompts and apply to both", "Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both", "Apply style": "Insert selected styles into 
prompt fields", - "Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.", + "Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style uses that as a placeholder for your prompt when you use the style in the future.", "Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.", "Inpainting conditioning mask strength": "Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.", @@ -92,12 +92,12 @@ titles = { "Weighted sum": "Result = A * (1 - M) + B * M", "Add difference": "Result = A + (B - C) * M", - "Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.", + "Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.", "Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.", - "Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.", - "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality.", + "Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resolution and lower quality.", + "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resolution and extremely low quality.", "Hires. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition", "Hires steps": "Number of sampling steps for upscaled picture. 
If 0, uses same as for original.", -- cgit v1.2.3 From 0fc1848e40dbd46c93753a2937403e1139ecd366 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 7 Jan 2023 11:25:41 +0200 Subject: CI: Use native actions/setup-python caching --- .github/workflows/on_pull_request.yaml | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/.github/workflows/on_pull_request.yaml b/.github/workflows/on_pull_request.yaml index b097d180..a168be5b 100644 --- a/.github/workflows/on_pull_request.yaml +++ b/.github/workflows/on_pull_request.yaml @@ -19,22 +19,19 @@ jobs: - name: Checkout Code uses: actions/checkout@v3 - name: Set up Python 3.10 - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: 3.10.6 - - uses: actions/cache@v2 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} - restore-keys: | - ${{ runner.os }}-pip- + cache: pip + cache-dependency-path: | + **/requirements*txt - name: Install PyLint run: | python -m pip install --upgrade pip pip install pylint # This lets PyLint check to see if it can resolve imports - name: Install dependencies - run : | + run: | export COMMANDLINE_ARGS="--skip-torch-cuda-test --exit" python launch.py - name: Analysing the code with pylint -- cgit v1.2.3 From a77873974b97618351791ea3015639be7d9f98d1 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 7 Jan 2023 11:34:02 +0200 Subject: ... also for tests. --- .github/workflows/run_tests.yaml | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 49dc92bd..ecb9012a 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -14,11 +14,9 @@ jobs: uses: actions/setup-python@v4 with: python-version: 3.10.6 - - uses: actions/cache@v3 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} - restore-keys: ${{ runner.os }}-pip- + cache: pip + cache-dependency-path: | + **/requirements*txt - name: Run tests run: python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test - name: Upload main app stdout-stderr -- cgit v1.2.3 From fdfce4711076c2ebac1089bac8169d043eb7978f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 13:29:47 +0300 Subject: add "from" resolution for hires fix to be less confusing. 
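To give a feel for the numbers behind the new "resize: from WxH to XxY" label added in the diff below, here is a rough sketch of how the hires-fix target is derived from the sliders. The helper name hires_target is hypothetical, and the real StableDiffusionProcessingTxt2Img.init() additionally handles aspect-ratio cropping and truncation, so treat this as a simplified illustration only:

    def hires_target(width, height, hr_scale, hr_resize_x, hr_resize_y):
        # "Upscale by" mode: both explicit resize fields left at 0.
        if hr_resize_x == 0 and hr_resize_y == 0:
            return int(width * hr_scale), int(height * hr_scale)
        # "Resize to" mode: a field left at 0 is derived from the aspect ratio.
        target_x = hr_resize_x or int(width * hr_resize_y / height)
        target_y = hr_resize_y or int(height * hr_resize_x / width)
        return target_x, target_y

    # e.g. hires_target(512, 512, 2.0, 0, 0) -> (1024, 1024),
    # shown in the UI as "resize: from 512x512 to 1024x1024"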
--- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index 6c765262..99483130 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -267,7 +267,7 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz with devices.autocast(): p.init([""], [0], [0]) - return f"resize to: {p.hr_upscale_to_x}x{p.hr_upscale_to_y}" + return f"resize: from {width}x{height} to {p.hr_upscale_to_x}x{p.hr_upscale_to_y}" def apply_styles(prompt, prompt_neg, style1_name, style2_name): -- cgit v1.2.3 From 151233399c4b79934bdbb7c12a97eeb6499572fb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 13:30:06 +0300 Subject: new screenshot --- README.md | 9 +++------ screenshot.png | Bin 525075 -> 420577 bytes txt2img_Screenshot.png | Bin 337094 -> 0 bytes 3 files changed, 3 insertions(+), 6 deletions(-) delete mode 100644 txt2img_Screenshot.png diff --git a/README.md b/README.md index fea6cb35..d783fdf0 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,7 @@ # Stable Diffusion web UI A browser interface based on Gradio library for Stable Diffusion. -![](txt2img_Screenshot.png) - -Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) wiki page for extra scripts developed by users. +![](screenshot.png) ## Features [Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features): @@ -97,9 +95,8 @@ Alternatively, use online services (like Google Colab): 1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH" 2. Install [git](https://git-scm.com/download/win). 3. Download the stable-diffusion-webui repository, for example by running `git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git`. -4. Place `model.ckpt` in the `models` directory (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it). -5. _*(Optional)*_ Place `GFPGANv1.4.pth` in the base directory, alongside `webui.py` (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it). -6. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user. +4. Place stable diffusion checkpoint (`model.ckpt`) in the `models/Stable-diffusion` directory (see [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) for where to get it). +5. Run `webui-user.bat` from Windows Explorer as normal, non-administrator, user. ### Automatic Installation on Linux 1. 
Install the dependencies: diff --git a/screenshot.png b/screenshot.png index 86c3209f..47a1be4e 100644 Binary files a/screenshot.png and b/screenshot.png differ diff --git a/txt2img_Screenshot.png b/txt2img_Screenshot.png deleted file mode 100644 index 6e2759a4..00000000 Binary files a/txt2img_Screenshot.png and /dev/null differ -- cgit v1.2.3 From 8a27730da5d5b25e28370e8ad94844856a839af9 Mon Sep 17 00:00:00 2001 From: Taithrah Date: Sat, 7 Jan 2023 06:11:57 -0500 Subject: Update hints.js Partial revert --- javascript/hints.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/javascript/hints.js b/javascript/hints.js index 73ab4a26..856e1389 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -12,8 +12,8 @@ titles = { "Batch size": "How many image to create in a single batch", "CFG Scale": "Classifier Free Guidance Scale - how strongly the image should conform to prompt - lower values produce more creative results", "Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result", - "\u{1f3b2}\ufe0f": "Set seed to -1 will set a new random number every time.", - "\u267b\ufe0f": "Reuse seed from last generation, most useful if it was randomized.", + "\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time", + "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed", "\u{1f3a8}": "Add a random artist to the prompt.", "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.", "\u{1f4c2}": "Open images output directory", -- cgit v1.2.3 From df3b31eb559ab9fabf7e513bdeddd5282c16f124 Mon Sep 17 00:00:00 2001 From: brkirch Date: Sat, 7 Jan 2023 07:04:59 -0500 Subject: In-place operations can break gradient calculation --- modules/sd_hijack_clip.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 5520c9b2..852afc66 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -247,9 +247,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers = torch.asarray(batch_multipliers).to(devices.device) original_mean = z.mean() - z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) + z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) new_mean = z.mean() - z *= original_mean / new_mean + z = z * (original_mean / new_mean) return z -- cgit v1.2.3 From 47534577eda63b0db1eeb8921c2a161773ec434c Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sat, 7 Jan 2023 07:51:35 -0500 Subject: api-get-memory --- modules/api/api.py | 37 +++++++++++++++++++++++++++++++++++++ modules/api/models.py | 4 ++++ 2 files changed, 41 insertions(+) diff --git a/modules/api/api.py b/modules/api/api.py index 2103709b..d2222b18 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -130,6 +130,7 @@ class Api: self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=PreprocessResponse) self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse) self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse) + 
self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse) def add_api_route(self, path: str, endpoint, **kwargs): if shared.cmd_opts.api_auth: @@ -465,6 +466,42 @@ class Api: shared.state.end() return TrainResponse(info = "train embedding error: {error}".format(error = error)) + def get_memory(self): + def gb(val: float): + return round(val / 1024 / 1024 / 1024, 2) + try: + import os, psutil + process = psutil.Process(os.getpid()) + res = process.memory_info() + ram_total = 100 * res.rss / process.memory_percent() + ram = { 'free': gb(ram_total - res.rss), 'used': gb(res.rss), 'total': gb(ram_total) } + except Exception as err: + ram = { 'error': f'{err}' } + try: + import torch + if torch.cuda.is_available(): + s = torch.cuda.mem_get_info() + system = { 'free': gb(s[0]), 'used': gb(s[1] - s[0]), 'total': gb(s[1]) } + s = dict(torch.cuda.memory_stats(shared.device)) + allocated = { 'current': gb(s['allocated_bytes.all.current']), 'peak': gb(s['allocated_bytes.all.peak']) } + reserved = { 'current': gb(s['reserved_bytes.all.current']), 'peak': gb(s['reserved_bytes.all.peak']) } + active = { 'current': gb(s['active_bytes.all.current']), 'peak': gb(s['active_bytes.all.peak']) } + inactive = { 'current': gb(s['inactive_split_bytes.all.current']), 'peak': gb(s['inactive_split_bytes.all.peak']) } + warnings = { 'retries': s['num_alloc_retries'], 'oom': s['num_ooms'] } + cuda = { + 'system': system, + 'active': active, + 'allocated': allocated, + 'reserved': reserved, + 'inactive': inactive, + 'events': warnings, + } + else: + cuda = { 'error': 'unavailable' } + except Exception as err: + cuda = { 'error': f'{err}' } + return MemoryResponse(ram = ram, cuda = cuda) + def launch(self, server_name, port): self.app.include_router(self.router) uvicorn.run(self.app, host=server_name, port=port) diff --git a/modules/api/models.py b/modules/api/models.py index 5fa63774..49bf1e7a 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -260,3 +260,7 @@ class EmbeddingItem(BaseModel): class EmbeddingsResponse(BaseModel): loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model") skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)") + +class MemoryResponse(BaseModel): + ram: dict[str, str] | dict[str, float] = Field(title="RAM", description="System memory stats") + cuda: dict[str, str] | dict[str, dict] = Field(title="CUDA", description="nVidia CUDA memory stats") -- cgit v1.2.3 From d38ede71d5330958f4bbac5f99c1be3c146b506a Mon Sep 17 00:00:00 2001 From: noodleanon <122053346+noodleanon@users.noreply.github.com> Date: Sat, 7 Jan 2023 14:21:31 +0000 Subject: Added script support in txt2img endpoint --- modules/api/api.py | 22 +++++++++++++++++++--- modules/api/models.py | 2 +- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index aa62a42e..0e8ea263 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -149,6 +149,14 @@ class Api: raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): + if txt2imgreq.script_name is not None: + if scripts.scripts_txt2img.scripts == []: + scripts.scripts_txt2img.initialize_scripts(True) + ui.create_ui() + + script_idx = script_name_to_index(txt2imgreq.script_name, 
scripts.scripts_txt2img.selectable_scripts) + script = scripts.scripts_txt2img.selectable_scripts[script_idx] + populate = txt2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), "do_not_save_samples": True, @@ -158,11 +166,20 @@ class Api: if populate.sampler_name: populate.sampler_index = None # prevent a warning later on + args = vars(populate) + args.pop('script_name', None) + with self.queue_lock: - p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **vars(populate)) + p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) shared.state.begin() - processed = process_images(p) + if 'script' in locals(): + p.outpath_grids = opts.outdir_txt2img_grids + p.outpath_samples = opts.outdir_txt2img_samples + p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args + processed = scripts.scripts_txt2img.run(p, *p.script_args) + else: + processed = process_images(p) shared.state.end() @@ -213,7 +230,6 @@ class Api: processed = scripts.scripts_img2img.run(p, *p.script_args) else: processed = process_images(p) - shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) diff --git a/modules/api/models.py b/modules/api/models.py index c85eb94d..ce43c858 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -100,7 +100,7 @@ class PydanticModelGenerator: StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator( "StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img, - [{"key": "sampler_index", "type": str, "default": "Euler"}] + [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "script_name", "type": str, "default": None}, {"key": "script_args", "type": list, "default": []}] ).generate_model() StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( -- cgit v1.2.3 From cabd95015b1085e989d9655ea805dbe5e33f5286 Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Sat, 7 Jan 2023 19:18:42 +0300 Subject: fix quicksettings name overlap --- style.css | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/style.css b/style.css index 76721756..d796cbe9 100644 --- a/style.css +++ b/style.css @@ -512,7 +512,7 @@ input[type="range"]{ border: none; background: none; flex: unset; - gap: 0.5em; + gap: 1em; } #quicksettings > div > div{ @@ -521,6 +521,17 @@ input[type="range"]{ padding: 0; } +#quicksettings > div > div > div > div > label > span { + position: relative; + margin-right: 9em; + margin-bottom: -1em; +} + +#quicksettings > div > div > label > span { + position: relative; + margin-bottom: -1em; +} + canvas[key="mask"] { z-index: 12 !important; filter: invert(); -- cgit v1.2.3 From 448b9cedab66e05b5b2800513ca334a769b42aa7 Mon Sep 17 00:00:00 2001 From: dan Date: Sat, 7 Jan 2023 21:07:27 +0800 Subject: Allow variable img size --- modules/textual_inversion/dataset.py | 18 +++++++++++------- modules/textual_inversion/textual_inversion.py | 4 ++-- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 88d68c76..375178ed 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -17,7 +17,7 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*") class DatasetEntry: - def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None): + def 
__init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, img_shape=None): self.filename = filename self.filename_text = filename_text self.latent_dist = latent_dist @@ -25,6 +25,7 @@ class DatasetEntry: self.cond = cond self.cond_text = cond_text self.pixel_values = pixel_values + self.img_shape = img_shape class PersonalizedBase(Dataset): @@ -33,8 +34,6 @@ class PersonalizedBase(Dataset): self.placeholder_token = placeholder_token - self.width = width - self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.dataset = [] @@ -59,7 +58,11 @@ class PersonalizedBase(Dataset): if shared.state.interrupted: raise Exception("interrupted") try: - image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) + image = Image.open(path).convert('RGB') + if width < 2000: + image = image.resize((width, height), PIL.Image.BICUBIC) + else: + assert batch_size == 1, 'variable img size must have batch size 1' except Exception: continue @@ -88,14 +91,14 @@ class PersonalizedBase(Dataset): if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)): latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) latent_sampling_method = "once" - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size) elif latent_sampling_method == "deterministic": # Works only for DiagonalGaussianDistribution latent_dist.std = 0 latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size) elif latent_sampling_method == "random": - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, img_shape=image.size) if not (self.tag_drop_out != 0 or self.shuffle_tags): entry.cond_text = self.create_text(filename_text) @@ -151,6 +154,7 @@ class BatchLoader: self.cond_text = [entry.cond_text for entry in data] self.cond = [entry.cond for entry in data] self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) + self.img_shape = [entry.img_shape for entry in data] #self.emb_index = [entry.emb_index for entry in data] #print(self.latent_sample.device) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 45882ed6..9f96d0fd 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -451,8 +451,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ else: p.prompt = batch.cond_text[0] p.steps = 20 - p.width = training_width - p.height = training_height + p.width = batch.img_shape[0][0] + p.height = batch.img_shape[0][1] preview_text = p.prompt -- cgit v1.2.3 From 669fb18d5222f53ae48abe0f30393d846c50ad91 Mon Sep 17 00:00:00 2001 From: dan Date: Sun, 8 Jan 2023 01:34:52 +0800 Subject: Add checkbox for variable training dims --- modules/hypernetworks/hypernetwork.py | 2 +- modules/textual_inversion/dataset.py | 4 ++-- 
modules/textual_inversion/textual_inversion.py | 4 ++-- modules/ui.py | 3 +++ 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b0cfbe71..dba52841 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -403,7 +403,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 375178ed..7f8a314f 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -29,7 +29,7 @@ class DatasetEntry: class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False): re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None self.placeholder_token = placeholder_token @@ -59,7 +59,7 @@ class PersonalizedBase(Dataset): raise Exception("interrupted") try: image = Image.open(path).convert('RGB') - if width < 2000: + if not varsize: image = image.resize((width, height), PIL.Image.BICUBIC) else: assert batch_size == 1, 'variable img size must have batch size 1' diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 9f96d0fd..110efd19 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -255,7 +255,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, 
save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") @@ -309,7 +309,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ pin_memory = shared.opts.pin_memory - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize) if shared.opts.save_training_settings_to_txt: save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) diff --git a/modules/ui.py b/modules/ui.py index 99483130..4e709a71 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1343,6 +1343,7 @@ def create_ui(): template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file") training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") + varsize = gr.Checkbox(label="Ignore dimension settings and do not resize images", value=False, elem_id="train_varsize") steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") with FormRow(): @@ -1449,6 +1450,7 @@ def create_ui(): log_directory, training_width, training_height, + varsize, steps, clip_grad_mode, clip_grad_value, @@ -1480,6 +1482,7 @@ def create_ui(): log_directory, training_width, training_height, + varsize, steps, clip_grad_mode, clip_grad_value, -- cgit v1.2.3 From 72497895b9b1948f86d9309fe897cbb70c20ba7e Mon Sep 17 00:00:00 2001 From: dan Date: Sun, 8 Jan 2023 01:36:00 +0800 Subject: 
Move batchsize check --- modules/hypernetworks/hypernetwork.py | 2 +- modules/textual_inversion/dataset.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index dba52841..32c67ccc 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -456,7 +456,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, pin_memory = shared.opts.pin_memory - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize) if shared.opts.save_training_settings_to_txt: saved_params = dict( diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 7f8a314f..bcad6848 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -46,6 +46,8 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" + if varsize: + assert batch_size == 1, 'variable img size must have batch size 1' self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] @@ -61,8 +63,6 @@ class PersonalizedBase(Dataset): image = Image.open(path).convert('RGB') if not varsize: image = image.resize((width, height), PIL.Image.BICUBIC) - else: - assert batch_size == 1, 'variable img size must have batch size 1' except Exception: continue -- cgit v1.2.3 From 984b86dd0abf0da7f6b116864c791a2bfe8859ef Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sat, 7 Jan 2023 13:08:21 -0700 Subject: Add fallback for Protocol import --- modules/sub_quadratic_attention.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index fea7aaac..93381bae 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -15,7 +15,13 @@ import torch from torch import Tensor from torch.utils.checkpoint import checkpoint import math -from typing import Optional, NamedTuple, Protocol, List + +try: + from typing import Protocol +except: + from typing_extensions import Protocol + +from typing import Optional, NamedTuple, List def narrow_trunc( input: Tensor, -- cgit v1.2.3 From a0c87f1fdf2b76b2ae4ef6c4b01ddaede3afab06 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 8 Jan 2023 08:52:26 +0300 Subject: skip images in embeddings dir if they have a second .preview 
extension --- modules/textual_inversion/textual_inversion.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 45882ed6..e85dd549 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -109,6 +109,10 @@ class EmbeddingDatabase: ext = ext.upper() if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: + _, second_ext = os.path.splitext(name) + if second_ext.upper() == '.PREVIEW': + return + embed_image = Image.open(path) if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: data = embedding_from_b64(embed_image.text['sd-ti-embedding']) -- cgit v1.2.3 From 085427de0efc9e9e7a6e9a5aebc6b5a69f0365e7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 8 Jan 2023 09:37:33 +0300 Subject: make it possible for extensions/scripts to add their own embedding directories --- modules/sd_hijack.py | 7 +- modules/textual_inversion/textual_inversion.py | 170 +++++++++++++++---------- 2 files changed, 108 insertions(+), 69 deletions(-) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index cfdb09d6..6b0d95af 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -83,10 +83,12 @@ class StableDiffusionModelHijack: clip = None optimization_method = None - embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) + embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase() - def hijack(self, m): + def __init__(self): + self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir) + def hijack(self, m): if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) @@ -117,7 +119,6 @@ class StableDiffusionModelHijack: self.layers = flatten(m) def undo_hijack(self, m): - if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: m.cond_stage_model = m.cond_stage_model.wrapped diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index e85dd549..217fe9eb 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -66,17 +66,41 @@ class Embedding: return self.cached_checksum +class DirWithTextualInversionEmbeddings: + def __init__(self, path): + self.path = path + self.mtime = None + + def has_changed(self): + if not os.path.isdir(self.path): + return False + + mt = os.path.getmtime(self.path) + if self.mtime is None or mt > self.mtime: + return True + + def update(self): + if not os.path.isdir(self.path): + return + + self.mtime = os.path.getmtime(self.path) + + class EmbeddingDatabase: - def __init__(self, embeddings_dir): + def __init__(self): self.ids_lookup = {} self.word_embeddings = {} self.skipped_embeddings = {} - self.dir_mtime = None - self.embeddings_dir = embeddings_dir self.expected_shape = -1 + self.embedding_dirs = {} - def register_embedding(self, embedding, model): + def add_embedding_dir(self, path): + self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path) + + def clear_embedding_dirs(self): + self.embedding_dirs.clear() + def register_embedding(self, embedding, model): self.word_embeddings[embedding.name] = embedding ids = model.cond_stage_model.tokenize([embedding.name])[0] @@ -93,69 +117,62 @@ class EmbeddingDatabase: vec = 
shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1) return vec.shape[1] - def load_textual_inversion_embeddings(self, force_reload = False): - mt = os.path.getmtime(self.embeddings_dir) - if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime: - return + def load_from_file(self, path, filename): + name, ext = os.path.splitext(filename) + ext = ext.upper() - self.dir_mtime = mt - self.ids_lookup.clear() - self.word_embeddings.clear() - self.skipped_embeddings.clear() - self.expected_shape = self.get_expected_shape() - - def process_file(path, filename): - name, ext = os.path.splitext(filename) - ext = ext.upper() - - if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: - _, second_ext = os.path.splitext(name) - if second_ext.upper() == '.PREVIEW': - return - - embed_image = Image.open(path) - if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: - data = embedding_from_b64(embed_image.text['sd-ti-embedding']) - name = data.get('name', name) - else: - data = extract_image_data_embed(embed_image) - name = data.get('name', name) - elif ext in ['.BIN', '.PT']: - data = torch.load(path, map_location="cpu") - else: + if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']: + _, second_ext = os.path.splitext(name) + if second_ext.upper() == '.PREVIEW': return - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - if hasattr(param_dict, '_parameters'): - param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - # diffuser concepts - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - else: - raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") - - vec = emb.detach().to(devices.device, dtype=torch.float32) - embedding = Embedding(vec, name) - embedding.step = data.get('step', None) - embedding.sd_checkpoint = data.get('sd_checkpoint', None) - embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) - embedding.vectors = vec.shape[0] - embedding.shape = vec.shape[-1] - - if self.expected_shape == -1 or self.expected_shape == embedding.shape: - self.register_embedding(embedding, shared.sd_model) + embed_image = Image.open(path) + if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: + data = embedding_from_b64(embed_image.text['sd-ti-embedding']) + name = data.get('name', name) else: - self.skipped_embeddings[name] = embedding + data = extract_image_data_embed(embed_image) + name = data.get('name', name) + elif ext in ['.BIN', '.PT']: + data = torch.load(path, map_location="cpu") + else: + return + + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + if hasattr(param_dict, '_parameters'): + param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + # diffuser concepts + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if 
len(emb.shape) == 1: + emb = emb.unsqueeze(0) + else: + raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + + vec = emb.detach().to(devices.device, dtype=torch.float32) + embedding = Embedding(vec, name) + embedding.step = data.get('step', None) + embedding.sd_checkpoint = data.get('sd_checkpoint', None) + embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) + embedding.vectors = vec.shape[0] + embedding.shape = vec.shape[-1] + + if self.expected_shape == -1 or self.expected_shape == embedding.shape: + self.register_embedding(embedding, shared.sd_model) + else: + self.skipped_embeddings[name] = embedding - for root, dirs, fns in os.walk(self.embeddings_dir): + def load_from_dir(self, embdir): + if not os.path.isdir(embdir.path): + return + + for root, dirs, fns in os.walk(embdir.path): for fn in fns: try: fullfn = os.path.join(root, fn) @@ -163,12 +180,32 @@ class EmbeddingDatabase: if os.stat(fullfn).st_size == 0: continue - process_file(fullfn, fn) + self.load_from_file(fullfn, fn) except Exception: print(f"Error loading embedding {fn}:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) continue + def load_textual_inversion_embeddings(self, force_reload=False): + if not force_reload: + need_reload = False + for path, embdir in self.embedding_dirs.items(): + if embdir.has_changed(): + need_reload = True + break + + if not need_reload: + return + + self.ids_lookup.clear() + self.word_embeddings.clear() + self.skipped_embeddings.clear() + self.expected_shape = self.get_expected_shape() + + for path, embdir in self.embedding_dirs.items(): + self.load_from_dir(embdir) + embdir.update() + print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") if len(self.skipped_embeddings) > 0: print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") @@ -251,14 +288,15 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat assert os.path.isfile(template_file), "Prompt template file doesn't exist" assert steps, "Max steps is empty or 0" assert isinstance(steps, int), "Max steps must be integer" - assert steps > 0 , "Max steps must be positive" + assert steps > 0, "Max steps must be positive" assert isinstance(save_model_every, int), "Save {name} must be integer" - assert save_model_every >= 0 , "Save {name} must be positive or 0" + assert save_model_every >= 0, "Save {name} must be positive or 0" assert isinstance(create_image_every, int), "Create image must be integer" - assert create_image_every >= 0 , "Create image must be positive or 0" + assert create_image_every >= 0, "Create image must be positive or 0" if save_model_every or create_image_every: assert log_directory, "Log directory is empty" + def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 -- cgit v1.2.3 From 1aca26816eb63adfc7ec798b15a479c3575cc6d4 Mon Sep 17 00:00:00 2001 From: Vladimir 
Repin <32306715+mezotaken@users.noreply.github.com> Date: Sun, 8 Jan 2023 13:07:35 +0300 Subject: use actual label for feature requests --- .github/ISSUE_TEMPLATE/feature_request.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index 8ca6e21f..35a88740 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -1,7 +1,7 @@ name: Feature request description: Suggest an idea for this project title: "[Feature Request]: " -labels: ["suggestion"] +labels: ["enhancement"] body: - type: checkboxes -- cgit v1.2.3 From 6d0cc1e239e0a43a2e6d696eae20c66fad0819bb Mon Sep 17 00:00:00 2001 From: noodleanon <122053346+noodleanon@users.noreply.github.com> Date: Sun, 8 Jan 2023 11:03:48 +0000 Subject: Corrected is_img2img param --- modules/api/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 0e8ea263..1785a6b4 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -151,7 +151,7 @@ class Api: def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): if txt2imgreq.script_name is not None: if scripts.scripts_txt2img.scripts == []: - scripts.scripts_txt2img.initialize_scripts(True) + scripts.scripts_txt2img.initialize_scripts(False) ui.create_ui() script_idx = script_name_to_index(txt2imgreq.script_name, scripts.scripts_txt2img.selectable_scripts) -- cgit v1.2.3 From 137ce534b2355a527cd1a50c192909161258b442 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 8 Jan 2023 16:14:38 +0300 Subject: remove some code duplication remove calls to locals() add a test for img2img with script --- modules/api/api.py | 33 ++++++++++++++++----------------- test/basic_features/img2img_test.py | 6 ++++++ 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 1785a6b4..5b6125f8 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -148,14 +148,20 @@ class Api: raise HTTPException(status_code=401, detail="Incorrect username or password", headers={"WWW-Authenticate": "Basic"}) - def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): - if txt2imgreq.script_name is not None: - if scripts.scripts_txt2img.scripts == []: - scripts.scripts_txt2img.initialize_scripts(False) - ui.create_ui() + def get_script(self, script_name, script_runner): + if script_name is None: + return None, None + + if not script_runner.scripts: + script_runner.initialize_scripts(False) + ui.create_ui() + + script_idx = script_name_to_index(script_name, script_runner.selectable_scripts) + script = script_runner.selectable_scripts[script_idx] + return script, script_idx - script_idx = script_name_to_index(txt2imgreq.script_name, scripts.scripts_txt2img.selectable_scripts) - script = scripts.scripts_txt2img.selectable_scripts[script_idx] + def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): + script, script_idx = self.get_script(txt2imgreq.script_name, scripts.scripts_txt2img) populate = txt2imgreq.copy(update={ # Override __init__ params "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index), @@ -173,7 +179,7 @@ class Api: p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args) shared.state.begin() - if 'script' in locals(): + if script is not None: p.outpath_grids = opts.outdir_txt2img_grids p.outpath_samples = opts.outdir_txt2img_samples p.script_args = 
[script_idx + 1] + [None] * (script.args_from - 1) + p.script_args @@ -182,7 +188,6 @@ class Api: processed = process_images(p) shared.state.end() - b64images = list(map(encode_pil_to_base64, processed.images)) return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js()) @@ -192,13 +197,7 @@ class Api: if init_images is None: raise HTTPException(status_code=404, detail="Init image not found") - if img2imgreq.script_name is not None: - if scripts.scripts_img2img.scripts == []: - scripts.scripts_img2img.initialize_scripts(True) - ui.create_ui() - - script_idx = script_name_to_index(img2imgreq.script_name, scripts.scripts_img2img.selectable_scripts) - script = scripts.scripts_img2img.selectable_scripts[script_idx] + script, script_idx = self.get_script(img2imgreq.script_name, scripts.scripts_img2img) mask = img2imgreq.mask if mask: @@ -223,7 +222,7 @@ class Api: p.init_images = [decode_base64_to_image(x) for x in init_images] shared.state.begin() - if 'script' in locals(): + if script is not None: p.outpath_grids = opts.outdir_img2img_grids p.outpath_samples = opts.outdir_img2img_samples p.script_args = [script_idx + 1] + [None] * (script.args_from - 1) + p.script_args diff --git a/test/basic_features/img2img_test.py b/test/basic_features/img2img_test.py index 0a9c1e8a..bd520b13 100644 --- a/test/basic_features/img2img_test.py +++ b/test/basic_features/img2img_test.py @@ -50,6 +50,12 @@ class TestImg2ImgWorking(unittest.TestCase): self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png")) self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) + def test_img2img_sd_upscale_performed(self): + self.simple_img2img["script_name"] = "sd upscale" + self.simple_img2img["script_args"] = ["", 8, "Lanczos", 2.0] + + self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) + if __name__ == "__main__": unittest.main() -- cgit v1.2.3 From cb255faec6e5f6b47b7632e6b7d450b9e2f6678b Mon Sep 17 00:00:00 2001 From: Lee Bousfield Date: Sun, 8 Jan 2023 10:17:50 -0700 Subject: Add support for loading VAEs from safetensor files --- modules/sd_vae.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index ac71d62d..9fcfd9db 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,4 +1,5 @@ import torch +import safetensors.torch import os import collections from collections import namedtuple @@ -72,8 +73,10 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): candidates = [ *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), + *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True), *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True) + *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), + *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True), ] if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): candidates.append(shared.cmd_opts.vae_path) @@ -137,6 +140,12 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"): if os.path.isfile(vae_file_try): vae_file = vae_file_try print(f"Using VAE found similar to selected model: {vae_file}") + # if still not found, try look for ".vae.safetensors" beside model + if vae_file == "auto": + 
vae_file_try = model_path + ".vae.safetensors" + if os.path.isfile(vae_file_try): + vae_file = vae_file_try + print(f"Using VAE found similar to selected model: {vae_file}") # No more fallbacks for auto if vae_file == "auto": vae_file = None @@ -163,8 +172,14 @@ def load_vae(model, vae_file=None): assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" print(f"Loading VAE weights from: {vae_file}") store_base_vae(model) - vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} + _, extension = os.path.splitext(vae_file) + if extension.lower() == ".safetensors": + vae_ckpt = safetensors.torch.load_file(vae_file, device=shared.weight_load_location) + else: + vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) + if "state_dict" in vae_ckpt: + vae_ckpt = vae_ckpt["state_dict"] + vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} _load_vae_dict(model, vae_dict_1) if cache_enabled: -- cgit v1.2.3 From 1d663a04da900b79132063e94c94ab379ebd14a8 Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Mon, 9 Jan 2023 14:11:37 +0300 Subject: make tests runnable without specifying subdirectory --- test/server_poll.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/server_poll.py b/test/server_poll.py index d4df697b..42d56a4c 100644 --- a/test/server_poll.py +++ b/test/server_poll.py @@ -15,7 +15,7 @@ def run_tests(proc, test_dir): break if proc.poll() is None: if test_dir is None: - test_dir = "" + test_dir = "test" suite = unittest.TestLoader().discover(test_dir, pattern="*_test.py", top_level_dir="test") result = unittest.TextTestRunner(verbosity=2).run(suite) return len(result.failures) + len(result.errors) -- cgit v1.2.3 From 3af488bdff983efc8e77f49b26c18847413754f4 Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Mon, 9 Jan 2023 14:29:28 +0300 Subject: try all tests --- .github/workflows/run_tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 49dc92bd..110a6a75 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -20,7 +20,7 @@ jobs: key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} restore-keys: ${{ runner.os }}-pip- - name: Run tests - run: python launch.py --tests basic_features --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test + run: python launch.py --tests --no-half --disable-opt-split-attention --use-cpu all --skip-torch-cuda-test - name: Upload main app stdout-stderr uses: actions/upload-artifact@v3 if: always() -- cgit v1.2.3 From d4fd2418efb0986a8226add0b800fb5c73ffb58c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 14:57:47 +0300 Subject: add an option to use old hiresfix width/height behavior add a visual effect to inactive hires fix elements --- javascript/hires_fix.js | 25 +++++++++++++++++++++++++ modules/generation_parameters_copypaste.py | 17 +++++++++++------ modules/processing.py | 26 ++++++++++++++++++++++++-- modules/shared.py | 1 + modules/ui.py | 23 ++++++++++++++--------- style.css | 4 ++++ 6 files changed, 79 insertions(+), 17 deletions(-) create mode 100644 javascript/hires_fix.js diff --git a/javascript/hires_fix.js 
b/javascript/hires_fix.js new file mode 100644 index 00000000..07fba549 --- /dev/null +++ b/javascript/hires_fix.js @@ -0,0 +1,25 @@ + +function setInactive(elem, inactive){ + console.log(elem) + if(inactive){ + elem.classList.add('inactive') + } else{ + elem.classList.remove('inactive') + } +} + +function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){ + console.log(enable, width, height, hr_scale, hr_resize_x, hr_resize_y) + + hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale') + hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x') + hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y') + + gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : "" + + setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0) + setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0) + setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0) + + return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y] +} diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 12a9de3d..f7f68b67 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -197,6 +197,15 @@ def restore_old_hires_fix_params(res): firstpass_width = res.get('First pass size-1', None) firstpass_height = res.get('First pass size-2', None) + if shared.opts.use_old_hires_fix_width_height: + hires_width = int(res.get("Hires resize-1", None)) + hires_height = int(res.get("Hires resize-2", None)) + + if hires_width is not None and hires_height is not None: + res['Size-1'] = hires_width + res['Size-2'] = hires_height + return + if firstpass_width is None or firstpass_height is None: return @@ -205,12 +214,8 @@ def restore_old_hires_fix_params(res): height = int(res.get("Size-2", 512)) if firstpass_width == 0 or firstpass_height == 0: - # old algorithm for auto-calculating first pass size - desired_pixel_count = 512 * 512 - actual_pixel_count = width * height - scale = math.sqrt(desired_pixel_count / actual_pixel_count) - firstpass_width = math.ceil(scale * width / 64) * 64 - firstpass_height = math.ceil(scale * height / 64) * 64 + from modules import processing + firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height) res['Size-1'] = firstpass_width res['Size-2'] = firstpass_height diff --git a/modules/processing.py b/modules/processing.py index 1d23b15f..f04a0e1e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -687,6 +687,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: return res +def old_hires_fix_first_pass_dimensions(width, height): + """old algorithm for auto-calculating first pass size""" + + desired_pixel_count = 512 * 512 + actual_pixel_count = width * height + scale = math.sqrt(desired_pixel_count / actual_pixel_count) + width = math.ceil(scale * width / 64) * 64 + height = math.ceil(scale * height / 64) * 64 + + return width, height + + class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None @@ -703,16 +715,26 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_upscale_to_y = hr_resize_y if firstphase_width != 0 or firstphase_height != 0: - print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr) - self.hr_scale = self.width / firstphase_width + 
self.hr_upscale_to_x = self.width + self.hr_upscale_to_y = self.height self.width = firstphase_width self.height = firstphase_height self.truncate_x = 0 self.truncate_y = 0 + self.applied_old_hires_behavior_to = None def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: + if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height): + self.hr_resize_x = self.width + self.hr_resize_y = self.height + self.hr_upscale_to_x = self.width + self.hr_upscale_to_y = self.height + + self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height) + self.applied_old_hires_behavior_to = (self.width, self.height) + if self.hr_resize_x == 0 and self.hr_resize_y == 0: self.extra_generation_params["Hires upscale"] = self.hr_scale self.hr_upscale_to_x = int(self.width * self.hr_scale) diff --git a/modules/shared.py b/modules/shared.py index a6712dae..a1e10201 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -398,6 +398,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('compatibility', "Compatibility"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), + "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { diff --git a/modules/ui.py b/modules/ui.py index 99483130..719c26b3 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -267,7 +267,7 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz with devices.autocast(): p.init([""], [0], [0]) - return f"resize: from {width}x{height} to {p.hr_upscale_to_x}x{p.hr_upscale_to_y}" + return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" def apply_styles(prompt, prompt_neg, style1_name, style2_name): @@ -745,15 +745,20 @@ def create_ui(): custom_inputs = modules.scripts.scripts_txt2img.setup_ui() hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] - hr_resolution_preview_args = dict( - fn=calc_resolution_hires, - inputs=hr_resolution_preview_inputs, - outputs=[hr_final_resolution], - show_progress=False - ) - for input in hr_resolution_preview_inputs: - input.change(**hr_resolution_preview_args) + input.change( + fn=calc_resolution_hires, + inputs=hr_resolution_preview_inputs, + outputs=[hr_final_resolution], + show_progress=False, + ) + input.change( + None, + _js="onCalcResolutionHires", + inputs=hr_resolution_preview_inputs, + outputs=[], + show_progress=False, + ) txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) diff --git a/style.css b/style.css index d796cbe9..ec5e4182 100644 --- a/style.css +++ b/style.css @@ -670,6 +670,10 @@ footer { min-width: auto; } +.inactive{ + opacity: 0.5; +} + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. 
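As an aside on this commit: the old hires-fix behavior it restores scales the requested resolution down to roughly 512x512 worth of pixels for the first pass, rounding each side up to a multiple of 64. A standalone sketch of that calculation, with the function body taken from the patch; the two example resolutions are illustrative, results worked out by hand:

import math

def old_hires_fix_first_pass_dimensions(width, height):
    # scale the requested resolution down to ~512x512 worth of pixels,
    # then round each side up to the next multiple of 64
    desired_pixel_count = 512 * 512
    actual_pixel_count = width * height
    scale = math.sqrt(desired_pixel_count / actual_pixel_count)
    width = math.ceil(scale * width / 64) * 64
    height = math.ceil(scale * height / 64) * 64
    return width, height

print(old_hires_fix_first_pass_dimensions(1024, 1024))  # (512, 512): scale is exactly 0.5
print(old_hires_fix_first_pass_dimensions(768, 512))    # (640, 448): scale ~0.816, sides rounded up to 64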
If you change anything above, you need to make sure it is RTL compliant by just running -- cgit v1.2.3 From 7d2bb86cce10ee6a8e81aaad810544a4ca38cec9 Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Mon, 9 Jan 2023 19:39:06 +0300 Subject: combine tests together, return set options test --- test/advanced_features/__init__.py | 0 test/advanced_features/extras_test.py | 29 --------------------- test/advanced_features/txt2img_test.py | 47 ---------------------------------- test/basic_features/extras_test.py | 29 +++++++++++++++++++++ test/basic_features/txt2img_test.py | 4 +++ test/basic_features/utils_test.py | 14 ++++++++++ 6 files changed, 47 insertions(+), 76 deletions(-) delete mode 100644 test/advanced_features/__init__.py delete mode 100644 test/advanced_features/extras_test.py delete mode 100644 test/advanced_features/txt2img_test.py create mode 100644 test/basic_features/extras_test.py diff --git a/test/advanced_features/__init__.py b/test/advanced_features/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/test/advanced_features/extras_test.py b/test/advanced_features/extras_test.py deleted file mode 100644 index 8763f8ed..00000000 --- a/test/advanced_features/extras_test.py +++ /dev/null @@ -1,29 +0,0 @@ -import unittest - - -class TestExtrasWorking(unittest.TestCase): - def setUp(self): - self.url_img2img = "http://localhost:7860/sdapi/v1/extra-single-image" - self.simple_extras = { - "resize_mode": 0, - "show_extras_results": True, - "gfpgan_visibility": 0, - "codeformer_visibility": 0, - "codeformer_weight": 0, - "upscaling_resize": 2, - "upscaling_resize_w": 128, - "upscaling_resize_h": 128, - "upscaling_crop": True, - "upscaler_1": "None", - "upscaler_2": "None", - "extras_upscaler_2_visibility": 0, - "image": "" - } - - -class TestExtrasCorrectness(unittest.TestCase): - pass - - -if __name__ == "__main__": - unittest.main() diff --git a/test/advanced_features/txt2img_test.py b/test/advanced_features/txt2img_test.py deleted file mode 100644 index 36ed7b9a..00000000 --- a/test/advanced_features/txt2img_test.py +++ /dev/null @@ -1,47 +0,0 @@ -import unittest -import requests - - -class TestTxt2ImgWorking(unittest.TestCase): - def setUp(self): - self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img" - self.simple_txt2img = { - "enable_hr": False, - "denoising_strength": 0, - "firstphase_width": 0, - "firstphase_height": 0, - "prompt": "example prompt", - "styles": [], - "seed": -1, - "subseed": -1, - "subseed_strength": 0, - "seed_resize_from_h": -1, - "seed_resize_from_w": -1, - "batch_size": 1, - "n_iter": 1, - "steps": 3, - "cfg_scale": 7, - "width": 64, - "height": 64, - "restore_faces": False, - "tiling": False, - "negative_prompt": "", - "eta": 0, - "s_churn": 0, - "s_tmax": 0, - "s_tmin": 0, - "s_noise": 1, - "sampler_index": "Euler a" - } - - def test_txt2img_with_restore_faces_performed(self): - self.simple_txt2img["restore_faces"] = True - self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) - - -class TestTxt2ImgCorrectness(unittest.TestCase): - pass - - -if __name__ == "__main__": - unittest.main() diff --git a/test/basic_features/extras_test.py b/test/basic_features/extras_test.py new file mode 100644 index 00000000..8763f8ed --- /dev/null +++ b/test/basic_features/extras_test.py @@ -0,0 +1,29 @@ +import unittest + + +class TestExtrasWorking(unittest.TestCase): + def setUp(self): + self.url_img2img = "http://localhost:7860/sdapi/v1/extra-single-image" + 
self.simple_extras = { + "resize_mode": 0, + "show_extras_results": True, + "gfpgan_visibility": 0, + "codeformer_visibility": 0, + "codeformer_weight": 0, + "upscaling_resize": 2, + "upscaling_resize_w": 128, + "upscaling_resize_h": 128, + "upscaling_crop": True, + "upscaler_1": "None", + "upscaler_2": "None", + "extras_upscaler_2_visibility": 0, + "image": "" + } + + +class TestExtrasCorrectness(unittest.TestCase): + pass + + +if __name__ == "__main__": + unittest.main() diff --git a/test/basic_features/txt2img_test.py b/test/basic_features/txt2img_test.py index 1c2674b2..bbc846ed 100644 --- a/test/basic_features/txt2img_test.py +++ b/test/basic_features/txt2img_test.py @@ -63,6 +63,10 @@ class TestTxt2ImgWorking(unittest.TestCase): self.simple_txt2img["n_iter"] = 2 self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + def test_txt2img_with_restore_faces_performed(self): + self.simple_txt2img["restore_faces"] = True + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + if __name__ == "__main__": unittest.main() diff --git a/test/basic_features/utils_test.py b/test/basic_features/utils_test.py index 765470c9..b3c4045a 100644 --- a/test/basic_features/utils_test.py +++ b/test/basic_features/utils_test.py @@ -18,6 +18,20 @@ class UtilsTests(unittest.TestCase): def test_options_get(self): self.assertEqual(requests.get(self.url_options).status_code, 200) + def test_options_write(self): + response = requests.get(self.url_options) + self.assertEqual(response.status_code, 200) + + pre_value = response.json()["send_seed"] + + self.assertEqual(requests.post(self.url_options, json={"send_seed":not pre_value}).status_code, 200) + + response = requests.get(self.url_options) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json()["send_seed"], not pre_value) + + requests.post(self.url_options, json={"send_seed": pre_value}) + def test_cmd_flags(self): self.assertEqual(requests.get(self.url_cmd_flags).status_code, 200) -- cgit v1.2.3 From 49c4509ce2302350210ff650fd26373518c46a79 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 19:58:35 +0300 Subject: use existing function for loading VAE weights from file --- modules/sd_vae.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 9fcfd9db..0a49daa1 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -3,7 +3,7 @@ import safetensors.torch import os import collections from collections import namedtuple -from modules import shared, devices, script_callbacks +from modules import shared, devices, script_callbacks, sd_models from modules.paths import models_path import glob from copy import deepcopy @@ -172,13 +172,8 @@ def load_vae(model, vae_file=None): assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" print(f"Loading VAE weights from: {vae_file}") store_base_vae(model) - _, extension = os.path.splitext(vae_file) - if extension.lower() == ".safetensors": - vae_ckpt = safetensors.torch.load_file(vae_file, device=shared.weight_load_location) - else: - vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - if "state_dict" in vae_ckpt: - vae_ckpt = vae_ckpt["state_dict"] + + vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location) vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} _load_vae_dict(model, vae_dict_1) @@ -210,10 
+205,12 @@ def _load_vae_dict(model, vae_dict_1): model.first_stage_model.load_state_dict(vae_dict_1) model.first_stage_model.to(devices.dtype_vae) + def clear_loaded_vae(): global loaded_vae_file loaded_vae_file = None + def reload_vae_weights(sd_model=None, vae_file="auto"): from modules import lowvram, devices, sd_hijack -- cgit v1.2.3 From cdfcbd995932ffa728db0cc00a5f97665c752103 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 20:08:48 +0300 Subject: Remove fallback for Protocol import and remove Protocol import and remove instances of Protocol in code add some whitespace between functions to be in line with other code in the repo --- modules/sub_quadratic_attention.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 93381bae..55052815 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -15,14 +15,9 @@ import torch from torch import Tensor from torch.utils.checkpoint import checkpoint import math - -try: - from typing import Protocol -except: - from typing_extensions import Protocol - from typing import Optional, NamedTuple, List + def narrow_trunc( input: Tensor, dim: int, @@ -31,12 +26,14 @@ def narrow_trunc( ) -> Tensor: return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start) + class AttnChunk(NamedTuple): exp_values: Tensor exp_weights_sum: Tensor max_score: Tensor -class SummarizeChunk(Protocol): + +class SummarizeChunk: @staticmethod def __call__( query: Tensor, @@ -44,7 +41,8 @@ class SummarizeChunk(Protocol): value: Tensor, ) -> AttnChunk: ... -class ComputeQueryChunkAttn(Protocol): + +class ComputeQueryChunkAttn: @staticmethod def __call__( query: Tensor, @@ -52,6 +50,7 @@ class ComputeQueryChunkAttn(Protocol): value: Tensor, ) -> Tensor: ... 
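For reference, the narrow_trunc helper kept at the top of this file clamps the requested slice length so torch.narrow never reads past the end of the tensor. A minimal demonstration; the example tensor and numbers are illustrative:

import torch
from torch import Tensor

def narrow_trunc(input: Tensor, dim: int, start: int, length: int) -> Tensor:
    # same logic as the helper above: shrink length when the requested
    # window would overrun the tensor along dim
    return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)

t = torch.arange(10).reshape(1, 10)
print(narrow_trunc(t, 1, 6, 8).shape)  # torch.Size([1, 4]): length truncated from 8 to 4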
+ def _summarize_chunk( query: Tensor, key: Tensor, @@ -72,6 +71,7 @@ def _summarize_chunk( max_score = max_score.squeeze(-1) return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) + def _query_chunk_attention( query: Tensor, key: Tensor, @@ -112,6 +112,7 @@ def _query_chunk_attention( all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) return all_values / all_weights + # TODO: refactor CrossAttention#get_attention_scores to share code with this def _get_attention_scores_no_kv_chunking( query: Tensor, @@ -131,10 +132,12 @@ def _get_attention_scores_no_kv_chunking( hidden_states_slice = torch.bmm(attn_probs, value) return hidden_states_slice + class ScannedChunk(NamedTuple): chunk_idx: int attn_chunk: AttnChunk + def efficient_dot_product_attention( query: Tensor, key: Tensor, -- cgit v1.2.3 From 00005ac9af10d58a75f7ce0aa04db78775808e93 Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Mon, 9 Jan 2023 21:01:28 +0300 Subject: add more tests --- test/basic_features/extras_test.py | 37 +++++++++++++++++++++++++++++++------ test/basic_features/img2img_test.py | 7 ++++++- test/basic_features/txt2img_test.py | 11 +++++++++-- test/basic_features/utils_test.py | 3 +++ 4 files changed, 49 insertions(+), 9 deletions(-) diff --git a/test/basic_features/extras_test.py b/test/basic_features/extras_test.py index 8763f8ed..0170c511 100644 --- a/test/basic_features/extras_test.py +++ b/test/basic_features/extras_test.py @@ -1,10 +1,12 @@ import unittest - +import requests +from gradio.processing_utils import encode_pil_to_base64 +from PIL import Image class TestExtrasWorking(unittest.TestCase): def setUp(self): - self.url_img2img = "http://localhost:7860/sdapi/v1/extra-single-image" - self.simple_extras = { + self.url_extras_single = "http://localhost:7860/sdapi/v1/extra-single-image" + self.extras_single = { "resize_mode": 0, "show_extras_results": True, "gfpgan_visibility": 0, @@ -17,12 +19,35 @@ class TestExtrasWorking(unittest.TestCase): "upscaler_1": "None", "upscaler_2": "None", "extras_upscaler_2_visibility": 0, - "image": "" + "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")) } + def test_simple_upscaling_performed(self): + self.extras_single["upscaler_1"] = "Lanczos" + self.assertEqual(requests.post(self.url_extras_single, json=self.extras_single).status_code, 200) + + +class TestPngInfoWorking(unittest.TestCase): + def setUp(self): + self.url_png_info = "http://localhost:7860/sdapi/v1/extra-single-image" + self.png_info = { + "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")) + } + + def test_png_info_performed(self): + self.assertEqual(requests.post(self.url_png_info, json=self.png_info).status_code, 200) + + +class TestInterrogateWorking(unittest.TestCase): + def setUp(self): + self.url_interrogate = "http://localhost:7860/sdapi/v1/extra-single-image" + self.interrogate = { + "image": encode_pil_to_base64(Image.open(r"test/test_files/img2img_basic.png")), + "model": "clip" + } -class TestExtrasCorrectness(unittest.TestCase): - pass + def test_interrogate_performed(self): + self.assertEqual(requests.post(self.url_interrogate, json=self.interrogate).status_code, 200) if __name__ == "__main__": diff --git a/test/basic_features/img2img_test.py b/test/basic_features/img2img_test.py index bd520b13..08c5c903 100644 --- a/test/basic_features/img2img_test.py +++ b/test/basic_features/img2img_test.py @@ -16,7 +16,7 @@ class TestImg2ImgWorking(unittest.TestCase): "inpainting_fill": 
0, "inpaint_full_res": False, "inpaint_full_res_padding": 0, - "inpainting_mask_invert": 0, + "inpainting_mask_invert": False, "prompt": "example prompt", "styles": [], "seed": -1, @@ -50,6 +50,11 @@ class TestImg2ImgWorking(unittest.TestCase): self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png")) self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) + def test_inpainting_with_inverted_masked_performed(self): + self.simple_img2img["mask"] = encode_pil_to_base64(Image.open(r"test/test_files/mask_basic.png")) + self.simple_img2img["inpainting_mask_invert"] = True + self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200) + def test_img2img_sd_upscale_performed(self): self.simple_img2img["script_name"] = "sd upscale" self.simple_img2img["script_args"] = ["", 8, "Lanczos", 2.0] diff --git a/test/basic_features/txt2img_test.py b/test/basic_features/txt2img_test.py index bbc846ed..5b27a7ec 100644 --- a/test/basic_features/txt2img_test.py +++ b/test/basic_features/txt2img_test.py @@ -41,6 +41,9 @@ class TestTxt2ImgWorking(unittest.TestCase): self.simple_txt2img["negative_prompt"] = "example negative prompt" self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + def test_txt2img_with_complex_prompt_performed(self): + self.simple_txt2img["prompt"] = "((emphasis)), (emphasis1:1.1), [to:1], [from::2], [from:to:0.3], [alt|alt1]" + def test_txt2img_not_square_image_performed(self): self.simple_txt2img["height"] = 128 self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) @@ -53,6 +56,10 @@ class TestTxt2ImgWorking(unittest.TestCase): self.simple_txt2img["tiling"] = True self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + def test_txt2img_with_restore_faces_performed(self): + self.simple_txt2img["restore_faces"] = True + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) + def test_txt2img_with_vanilla_sampler_performed(self): self.simple_txt2img["sampler_index"] = "PLMS" self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) @@ -63,8 +70,8 @@ class TestTxt2ImgWorking(unittest.TestCase): self.simple_txt2img["n_iter"] = 2 self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) - def test_txt2img_with_restore_faces_performed(self): - self.simple_txt2img["restore_faces"] = True + def test_txt2img_batch_performed(self): + self.simple_txt2img["batch_size"] = 2 self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) diff --git a/test/basic_features/utils_test.py b/test/basic_features/utils_test.py index b3c4045a..94e00253 100644 --- a/test/basic_features/utils_test.py +++ b/test/basic_features/utils_test.py @@ -14,6 +14,7 @@ class UtilsTests(unittest.TestCase): self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles" self.url_artist_categories = "http://localhost:7860/sdapi/v1/artist-categories" self.url_artists = "http://localhost:7860/sdapi/v1/artists" + self.url_embeddings = "http://localhost:7860/sdapi/v1/embeddings" def test_options_get(self): self.assertEqual(requests.get(self.url_options).status_code, 200) @@ -62,6 +63,8 @@ class UtilsTests(unittest.TestCase): def test_artists(self): self.assertEqual(requests.get(self.url_artists).status_code, 200) + def test_embeddings(self): + 
self.assertEqual(requests.get(self.url_artists).status_code, 200) if __name__ == "__main__": unittest.main() -- cgit v1.2.3 From 43bb5190fc9e7ae479a5dc6640be202c9a71e464 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 22:52:23 +0300 Subject: remove/simplify some changes from #6481 --- modules/textual_inversion/dataset.py | 14 +++++--------- modules/textual_inversion/textual_inversion.py | 4 ++-- modules/ui.py | 2 +- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index bcad6848..fa48708e 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -17,7 +17,7 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*") class DatasetEntry: - def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, img_shape=None): + def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None): self.filename = filename self.filename_text = filename_text self.latent_dist = latent_dist @@ -25,7 +25,6 @@ class DatasetEntry: self.cond = cond self.cond_text = cond_text self.pixel_values = pixel_values - self.img_shape = img_shape class PersonalizedBase(Dataset): @@ -46,12 +45,10 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" - if varsize: - assert batch_size == 1, 'variable img size must have batch size 1' + assert batch_size == 1 or not varsize, 'variable img size must have batch size 1' self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] - self.shuffle_tags = shuffle_tags self.tag_drop_out = tag_drop_out @@ -91,14 +88,14 @@ class PersonalizedBase(Dataset): if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)): latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) latent_sampling_method = "once" - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) elif latent_sampling_method == "deterministic": # Works only for DiagonalGaussianDistribution latent_dist.std = 0 latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) elif latent_sampling_method == "random": - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, img_shape=image.size) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist) if not (self.tag_drop_out != 0 or self.shuffle_tags): entry.cond_text = self.create_text(filename_text) @@ -154,7 +151,6 @@ class BatchLoader: self.cond_text = [entry.cond_text for entry in data] self.cond = [entry.cond for entry in data] self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) - self.img_shape = [entry.img_shape for entry in data] #self.emb_index = [entry.emb_index for entry in data] 
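One detail from the dataset change above: the old nested check (if varsize: assert batch_size == 1) was folded into a single assert. A quick standalone sanity check, not part of the patch, that both forms accept exactly the same combinations:

for batch_size in (1, 4):
    for varsize in (False, True):
        old_ok = not (varsize and batch_size != 1)  # old form: only checked when varsize is set
        new_ok = batch_size == 1 or not varsize     # new combined assert condition
        assert old_ok == new_ok, (batch_size, varsize)
print("both forms are equivalent")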
#print(self.latent_sample.device) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index ad76297e..14be2c96 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -492,8 +492,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ else: p.prompt = batch.cond_text[0] p.steps = 20 - p.width = batch.img_shape[0][0] - p.height = batch.img_shape[0][1] + p.width = training_width + p.height = training_height preview_text = p.prompt diff --git a/modules/ui.py b/modules/ui.py index 9d6b141e..ddfe1b1a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1348,7 +1348,7 @@ def create_ui(): template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file") training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") - varsize = gr.Checkbox(label="Ignore dimension settings and do not resize images", value=False, elem_id="train_varsize") + varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") with FormRow(): -- cgit v1.2.3 From 56ed085edf25482a957b08479206154359f9185d Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Mon, 9 Jan 2023 23:19:08 +0300 Subject: move template so github could use it --- .../PULL_REQUEST_TEMPLATE/pull_request_template.md | 28 ---------------------- .github/pull_request_template.md | 28 ++++++++++++++++++++++ 2 files changed, 28 insertions(+), 28 deletions(-) delete mode 100644 .github/PULL_REQUEST_TEMPLATE/pull_request_template.md create mode 100644 .github/pull_request_template.md diff --git a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md b/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md deleted file mode 100644 index 86009613..00000000 --- a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md +++ /dev/null @@ -1,28 +0,0 @@ -# Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request! - -If you have a large change, pay special attention to this paragraph: - -> Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature. - -Otherwise, after making sure you're following the rules described in wiki page, remove this section and continue on. - -**Describe what this pull request is trying to achieve.** - -A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code. - -**Additional notes and description of your changes** - -More technical discussion about your changes go here, plus anything that a maintainer might have to specifically take a look at, or be wary of. - -**Environment this was tested in** - -List the environment you have developed / tested this on. 
As per the contributing page, changes should be able to work on Windows out of the box. - - OS: [e.g. Windows, Linux] - - Browser [e.g. chrome, safari] - - Graphics card [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB] - -**Screenshots or videos of your changes** - -If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made. - -This is **required** for anything that touches the user interface. \ No newline at end of file diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 00000000..69056331 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,28 @@ +# Please read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) before submitting a pull request! + +If you have a large change, pay special attention to this paragraph: + +> Before making changes, if you think that your feature will result in more than 100 lines changing, find me and talk to me about the feature you are proposing. It pains me to reject the hard work someone else did, but I won't add everything to the repo, and it's better if the rejection happens before you have to waste time working on the feature. + +Otherwise, after making sure you're following the rules described in wiki page, remove this section and continue on. + +**Describe what this pull request is trying to achieve.** + +A clear and concise description of what you're trying to accomplish with this, so your intent doesn't have to be extracted from your code. + +**Additional notes and description of your changes** + +More technical discussion about your changes go here, plus anything that a maintainer might have to specifically take a look at, or be wary of. + +**Environment this was tested in** + +List the environment you have developed / tested this on. As per the contributing page, changes should be able to work on Windows out of the box. + - OS: [e.g. Windows, Linux] + - Browser: [e.g. chrome, safari] + - Graphics card: [e.g. NVIDIA RTX 2080 8GB, AMD RX 6600 8GB] + +**Screenshots or videos of your changes** + +If applicable, screenshots or a video showing off your changes. If it edits an existing UI, it should ideally contain a comparison of what used to be there, before your changes were made. + +This is **required** for anything that touches the user interface. 
\ No newline at end of file -- cgit v1.2.3 From 1fbb6f9ebe48326a3b12ecf611105dbc4a46891e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 23:35:40 +0300 Subject: make a dropdown for prompt template selection --- modules/hypernetworks/hypernetwork.py | 7 ++++-- modules/shared.py | 1 + modules/textual_inversion/textual_inversion.py | 35 ++++++++++++++++++++------ modules/ui.py | 11 ++++++-- webui.py | 3 +++ 5 files changed, 45 insertions(+), 12 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 32c67ccc..ea3f1db9 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -24,6 +24,7 @@ from statistics import stdev, mean optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"} + class HypernetworkModule(torch.nn.Module): multiplier = 1.0 activation_dict = { @@ -403,13 +404,15 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. 
from modules import images save_hypernetwork_every = save_hypernetwork_every or 0 create_image_every = create_image_every or 0 - textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") + template_file = textual_inversion.textual_inversion_templates.get(template_filename, None) + textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") + template_file = template_file.path path = shared.hypernetworks.get(hypernetwork_name, None) shared.loaded_hypernetwork = Hypernetwork() diff --git a/modules/shared.py b/modules/shared.py index a1e10201..aa37c8ce 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -33,6 +33,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 14be2c96..5420903f 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -2,6 +2,7 @@ import os import sys import traceback import inspect +from collections import namedtuple import torch import tqdm @@ -15,12 +16,26 @@ from modules import shared, devices, sd_hijack, processing, sd_models, images, s import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler -from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64, - insert_image_data_embed, extract_image_data_embed, - caption_image_overlay) +from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay from modules.textual_inversion.logging import save_settings_to_file +TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"]) +textual_inversion_templates = {} + + +def list_textual_inversion_templates(): + textual_inversion_templates.clear() + + for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir): + for fn in fns: + path = os.path.join(root, fn) + + textual_inversion_templates[fn] = TextualInversionTemplate(fn, path) + + return textual_inversion_templates + + class Embedding: def 
__init__(self, vec, name, step=None): self.vec = vec @@ -274,7 +289,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): }) -def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): +def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"): assert model_name, f"{name} not selected" assert learn_rate, "Learning rate is empty or 0" assert isinstance(batch_size, int), "Batch size must be integer" @@ -284,8 +299,9 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat assert data_root, "Dataset directory is empty" assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" - assert template_file, "Prompt template file is empty" - assert os.path.isfile(template_file), "Prompt template file doesn't exist" + assert template_filename, "Prompt template file not selected" + assert template_file, f"Prompt template file {template_filename} not found" + assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist" assert steps, "Max steps is empty or 0" assert isinstance(steps, int), "Max steps must be integer" assert steps > 0, "Max steps must be positive" @@ -296,10 +312,13 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + +def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 - validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") + template_file = textual_inversion_templates.get(template_filename, None) + validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding") + template_file = template_file.path shared.state.job = "train-embedding" shared.state.textinfo = "Initializing textual inversion training..." 
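The template registry introduced above maps each filename found under the templates directory to a (name, path) namedtuple, so the UI only has to pass a filename and training resolves it back to a path before validating. A self-contained sketch of the same pattern, assuming a local textual_inversion_templates directory containing style_filewords.txt in place of shared.cmd_opts.textual_inversion_templates_dir:

import os
from collections import namedtuple

TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
templates = {}

def list_templates(templates_dir):
    # mirror of list_textual_inversion_templates above: one dict entry per file found
    templates.clear()
    for root, dirs, fns in os.walk(templates_dir):
        for fn in fns:
            templates[fn] = TextualInversionTemplate(fn, os.path.join(root, fn))
    return templates

list_templates("textual_inversion_templates")
# the dropdown stores only the filename; training looks it up and validates
chosen = templates.get("style_filewords.txt", None)
assert chosen, "Prompt template file style_filewords.txt not found"
print(chosen.path)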
diff --git a/modules/ui.py b/modules/ui.py index ddfe1b1a..b6079aec 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -37,7 +37,7 @@ from modules import prompt_parser from modules.images import save_image from modules.sd_hijack import model_hijack from modules.sd_samplers import samplers, samplers_for_img2img -import modules.textual_inversion.ui +from modules.textual_inversion import textual_inversion import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text @@ -1322,6 +1322,9 @@ def create_ui(): outputs=[process_focal_crop_row], ) + def get_textual_inversion_template_names(): + return sorted([x for x in textual_inversion.textual_inversion_templates]) + with gr.Tab(label="Train"): gr.HTML(value="
Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]
") with FormRow(): @@ -1345,7 +1348,11 @@ def create_ui(): dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") - template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file") + + with FormRow(): + template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) + create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") + training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") diff --git a/webui.py b/webui.py index 8737e593..47d372c7 100644 --- a/webui.py +++ b/webui.py @@ -33,6 +33,7 @@ import modules.sd_models import modules.sd_vae import modules.txt2img import modules.script_callbacks +import modules.textual_inversion.textual_inversion import modules.ui from modules import modelloader @@ -67,6 +68,8 @@ def initialize(): modules.sd_vae.refresh_vae_list() + modules.textual_inversion.textual_inversion.list_textual_inversion_templates() + try: modules.sd_models.load_model() except Exception as e: -- cgit v1.2.3 From 95727312ca5913876aa1c74f47d1ff6d93bb6b1f Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 9 Jan 2023 16:54:12 -0500 Subject: remove bytes -> gb conversion --- modules/api/api.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index d2222b18..1c121ff0 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -467,26 +467,24 @@ class Api: return TrainResponse(info = "train embedding error: {error}".format(error = error)) def get_memory(self): - def gb(val: float): - return round(val / 1024 / 1024 / 1024, 2) try: import os, psutil process = psutil.Process(os.getpid()) - res = process.memory_info() - ram_total = 100 * res.rss / process.memory_percent() - ram = { 'free': gb(ram_total - res.rss), 'used': gb(res.rss), 'total': gb(ram_total) } + res = process.memory_info() # only rss is cross-platform guaranteed so we dont rely on other values + ram_total = 100 * res.rss / process.memory_percent() # and total memory is calculated as actual value is not cross-platform safe + ram = { 'free': ram_total - res.rss, 'used': res.rss, 'total': ram_total } except Exception as err: ram = { 'error': f'{err}' } try: import torch if torch.cuda.is_available(): s = torch.cuda.mem_get_info() - system = { 'free': gb(s[0]), 'used': gb(s[1] - s[0]), 'total': gb(s[1]) } + system = { 'free': s[0], 'used': s[1] - s[0], 'total': s[1] } s = dict(torch.cuda.memory_stats(shared.device)) - allocated = { 'current': gb(s['allocated_bytes.all.current']), 'peak': gb(s['allocated_bytes.all.peak']) } - reserved = { 'current': gb(s['reserved_bytes.all.current']), 'peak': gb(s['reserved_bytes.all.peak']) } - active = { 'current': gb(s['active_bytes.all.current']), 'peak': 
gb(s['active_bytes.all.peak']) } - inactive = { 'current': gb(s['inactive_split_bytes.all.current']), 'peak': gb(s['inactive_split_bytes.all.peak']) } + allocated = { 'current': s['allocated_bytes.all.current'], 'peak': s['allocated_bytes.all.peak'] } + reserved = { 'current': s['reserved_bytes.all.current'], 'peak': s['reserved_bytes.all.peak'] } + active = { 'current': s['active_bytes.all.current'], 'peak': s['active_bytes.all.peak'] } + inactive = { 'current': s['inactive_split_bytes.all.current'], 'peak': s['inactive_split_bytes.all.peak'] } warnings = { 'retries': s['num_alloc_retries'], 'oom': s['num_ooms'] } cuda = { 'system': system, -- cgit v1.2.3 From 3fe9e9e54dcfc41d7c5ee6976f83b0de29fd3dda Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 10 Jan 2023 02:17:33 +0300 Subject: fix broken resolution detection when pasting parameters with old hires fix enabled --- modules/generation_parameters_copypaste.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index f7f68b67..620aa606 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -198,10 +198,10 @@ def restore_old_hires_fix_params(res): firstpass_height = res.get('First pass size-2', None) if shared.opts.use_old_hires_fix_width_height: - hires_width = int(res.get("Hires resize-1", None)) - hires_height = int(res.get("Hires resize-2", None)) + hires_width = int(res.get("Hires resize-1", 0)) + hires_height = int(res.get("Hires resize-2", 0)) - if hires_width is not None and hires_height is not None: + if hires_width and hires_height: res['Size-1'] = hires_width res['Size-2'] = hires_height return -- cgit v1.2.3 From 552d7b90bf483c160cd20740f7acd7fccbc02e6f Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 9 Jan 2023 18:34:26 -0500 Subject: allow model load if previous model failed --- modules/sd_models.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 76a89e88..0a6d55ca 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -49,6 +49,9 @@ def checkpoint_tiles(): def find_checkpoint_config(info): + if info is None: + return shared.cmd_opts.config + config = os.path.splitext(info.filename)[0] + ".yaml" if os.path.exists(config): return config @@ -345,14 +348,16 @@ def reload_model_weights(sd_model=None, info=None): if not sd_model: sd_model = shared.sd_model + if sd_model is None: # previous model load failed + current_checkpoint_info = None + else: + current_checkpoint_info = sd_model.sd_checkpoint_info + if sd_model.sd_model_checkpoint == checkpoint_info.filename: + return - current_checkpoint_info = sd_model.sd_checkpoint_info checkpoint_config = find_checkpoint_config(current_checkpoint_info) - if sd_model.sd_model_checkpoint == checkpoint_info.filename: - return - - if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): + if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): del sd_model checkpoints_loaded.clear() load_model(checkpoint_info) -- cgit v1.2.3 From 2275f130bfe437c3245a66559f92af94d0e4d8ff Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 9 Jan 2023 21:23:58 -0500 
Subject: relax response type check enforcement --- modules/api/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/api/models.py b/modules/api/models.py index 880edde6..034b4aa0 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -262,5 +262,5 @@ class EmbeddingsResponse(BaseModel): skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)") class MemoryResponse(BaseModel): - ram: dict[str, str] | dict[str, float] = Field(title="RAM", description="System memory stats") - cuda: dict[str, str] | dict[str, dict] = Field(title="CUDA", description="nVidia CUDA memory stats") + ram: dict = Field(title="RAM", description="System memory stats") + cuda: dict = Field(title="CUDA", description="nVidia CUDA memory stats") -- cgit v1.2.3 From a4a5475cfa3c68af6cb046081002a72f862ce4be Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Tue, 10 Jan 2023 14:56:57 +0900 Subject: Variable dropout rate Implements variable dropout rate from #4549 Fixes hypernetwork multiplier being able to be modified during training, also fixes user errors by setting multiplier value to lower values for training. Changes function name to match torch.nn.Module standard Fixes RNG reset issue when generating previews by restoring RNG state --- modules/hypernetworks/hypernetwork.py | 101 +++++++++++++++++++++++++--------- modules/hypernetworks/ui.py | 4 +- modules/ui.py | 4 +- 3 files changed, 81 insertions(+), 28 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index ea3f1db9..300d3975 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -39,7 +39,7 @@ class HypernetworkModule(torch.nn.Module): activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', - add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False): + add_layer_norm=False, activate_output=False, dropout_structure=None): super().__init__() assert layer_structure is not None, "layer_structure must not be None" @@ -64,9 +64,12 @@ class HypernetworkModule(torch.nn.Module): if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) - # Add dropout except last layer - if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2): - linears.append(torch.nn.Dropout(p=0.3)) + # Everything should now be parsed into dropout structure, and applied here. + # Since we only have dropouts after layers, dropout structure should start with 0 and end with 0. + if dropout_structure is not None and dropout_structure[i+1] > 0: + assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!" + linears.append(torch.nn.Dropout(p=dropout_structure[i+1])) + # Code explanation: [1, 2, 1] -> dropout is missing when last_layer_dropout is false. [1, 2, 2, 1] -> [0, 0.3, 0, 0], when it's True, [0, 0.3, 0.3, 0].
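The "Code explanation" comment above can be checked against parse_dropout_structure, the helper this commit adds a little further down. A standalone copy of that helper with the documented mappings asserted explicitly:

def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
    # same logic as the helper added by this patch: one dropout rate per
    # layer boundary, always 0 at the start and end
    if layer_structure is None:
        layer_structure = [1, 2, 1]
    if not use_dropout:
        return [0] * len(layer_structure)
    dropout_values = [0]
    dropout_values.extend([0.3] * (len(layer_structure) - 3))
    if last_layer_dropout:
        dropout_values.append(0.3)
    else:
        dropout_values.append(0)
    dropout_values.append(0)
    return dropout_values

assert parse_dropout_structure([1, 2, 1], True, False) == [0, 0, 0]           # too short: dropout is missing
assert parse_dropout_structure([1, 2, 2, 1], True, False) == [0, 0.3, 0, 0]
assert parse_dropout_structure([1, 2, 2, 1], True, True) == [0, 0.3, 0.3, 0]
assert parse_dropout_structure([1, 2, 2, 1], False, True) == [0, 0, 0, 0]     # use_dropout off disables everything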
self.linear = torch.nn.Sequential(*linears) @@ -113,7 +116,7 @@ class HypernetworkModule(torch.nn.Module): state_dict[to] = x def forward(self, x): - return x + self.linear(x) * self.multiplier + return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1) def trainables(self): layer_structure = [] @@ -126,6 +129,21 @@ class HypernetworkModule(torch.nn.Module): def apply_strength(value=None): HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength +#param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check. +def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout): + if layer_structure is None: + layer_structure = [1, 2, 1] + if not use_dropout: + return [0] * len(layer_structure) + dropout_values = [0] + dropout_values.extend([0.3] * (len(layer_structure) - 3)) + if last_layer_dropout: + dropout_values.append(0.3) + else: + dropout_values.append(0) + dropout_values.append(0) + return dropout_values + class Hypernetwork: filename = None @@ -144,18 +162,22 @@ class Hypernetwork: self.add_layer_norm = add_layer_norm self.use_dropout = use_dropout self.activate_output = activate_output - self.last_layer_dropout = kwargs['last_layer_dropout'] if 'last_layer_dropout' in kwargs else True + self.last_layer_dropout = kwargs.get('last_layer_dropout', True) + self.dropout_structure = kwargs.get('dropout_structure', None) + if self.dropout_structure is None: + self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout) self.optimizer_name = None self.optimizer_state_dict = None + self.optional_info = None for size in enable_sizes or []: self.layers[size] = ( HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, - self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout), + self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure), HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, - self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout), + self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure), ) - self.eval_mode() + self.eval() def weights(self): res = [] @@ -164,14 +186,14 @@ class Hypernetwork: res += layer.parameters() return res - def train_mode(self): + def train(self, mode=True): for k, layers in self.layers.items(): for layer in layers: - layer.train() + layer.train(mode=mode) for param in layer.parameters(): - param.requires_grad = True + param.requires_grad = mode - def eval_mode(self): + def eval(self): for k, layers in self.layers.items(): for layer in layers: layer.eval() @@ -191,11 +213,13 @@ class Hypernetwork: state_dict['activation_func'] = self.activation_func state_dict['is_layer_norm'] = self.add_layer_norm state_dict['weight_initialization'] = self.weight_init - state_dict['use_dropout'] = self.use_dropout state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name state_dict['activate_output'] = self.activate_output - state_dict['last_layer_dropout'] = self.last_layer_dropout + state_dict['use_dropout'] = self.use_dropout + state_dict['dropout_structure'] = self.dropout_structure + state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None 
else self.last_layer_dropout + state_dict['optional_info'] = self.optional_info if self.optional_info else None if self.optimizer_name is not None: optimizer_saved_dict['optimizer_name'] = self.optimizer_name @@ -215,43 +239,56 @@ class Hypernetwork: self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) print(self.layer_structure) + optional_info = state_dict.get('optional_info', None) + if optional_info is not None: + print(f"INFO:\n {optional_info}\n") + self.optional_info = optional_info self.activation_func = state_dict.get('activation_func', None) print(f"Activation function is {self.activation_func}") self.weight_init = state_dict.get('weight_initialization', 'Normal') print(f"Weight initialization is {self.weight_init}") self.add_layer_norm = state_dict.get('is_layer_norm', False) print(f"Layer norm is set to {self.add_layer_norm}") - self.use_dropout = state_dict.get('use_dropout', False) + self.dropout_structure = state_dict.get('dropout_structure', None) + self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False) print(f"Dropout usage is set to {self.use_dropout}" ) self.activate_output = state_dict.get('activate_output', True) print(f"Activate last layer is set to {self.activate_output}") self.last_layer_dropout = state_dict.get('last_layer_dropout', False) + # Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0. + if self.dropout_structure is None: + print("Using previous dropout structure") + self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout) + print(f"Dropout structure is set to {self.dropout_structure}") optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {} - self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW') - print(f"Optimizer name is {self.optimizer_name}") + if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None): self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) else: self.optimizer_state_dict = None if self.optimizer_state_dict: + self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW') print("Loaded existing optimizer from checkpoint") + print(f"Optimizer name is {self.optimizer_name}") else: + self.optimizer_name = "AdamW" print("No saved optimizer exists in checkpoint") for size, sd in state_dict.items(): if type(size) == int: self.layers[size] = ( HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, - self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout), + self.add_layer_norm, self.activate_output, self.dropout_structure), HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, - self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout), + self.add_layer_norm, self.activate_output, self.dropout_structure), ) self.name = state_dict.get('name', self.name) self.step = state_dict.get('step', 0) self.sd_checkpoint = state_dict.get('sd_checkpoint', None) self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) + self.eval() def list_hypernetworks(path): @@ -379,9 +416,10 @@ def report_statistics(loss_info:dict): print(e) -def create_hypernetwork(name, enable_sizes, overwrite_old, 
layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): +def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): # Remove illegal characters from name. name = "".join( x for x in name if (x.isalnum() or x in "._- ")) + assert name, "Name cannot be empty!" fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") if not overwrite_old: @@ -390,6 +428,11 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, if type(layer_structure) == str: layer_structure = [float(x.strip()) for x in layer_structure.split(",")] + if use_dropout and dropout_structure and type(dropout_structure) == str: + dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")] + else: + dropout_structure = [0] * len(layer_structure) + hypernet = modules.hypernetworks.hypernetwork.Hypernetwork( name=name, enable_sizes=[int(x) for x in enable_sizes], @@ -398,6 +441,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, weight_init=weight_init, add_layer_norm=add_layer_norm, use_dropout=use_dropout, + dropout_structure=dropout_structure ) hypernet.save(fn) @@ -480,7 +524,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, shared.sd_model.first_stage_model.to(devices.cpu) weights = hypernetwork.weights() - hypernetwork.train_mode() + hypernetwork.train() # Here we use optimizer from saved HN, or we can specify as UI option. if hypernetwork.optimizer_name in optimizer_dict: @@ -594,7 +638,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if images_dir is not None and steps_done % create_image_every == 0: forced_filename = f'{hypernetwork_name}-{steps_done}' last_saved_image = os.path.join(images_dir, forced_filename) - hypernetwork.eval_mode() + hypernetwork.eval() + rng_state = torch.get_rng_state() + cuda_rng_state = None + if torch.cuda.is_available(): + cuda_rng_state = torch.cuda.get_rng_state_all() shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device) @@ -627,7 +675,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) - hypernetwork.train_mode() + torch.set_rng_state(rng_state) + if torch.cuda.is_available(): + torch.cuda.set_rng_state_all(cuda_rng_state) + hypernetwork.train() if image is not None: shared.state.current_image = image last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) @@ -649,7 +700,7 @@ Last saved image: {html.escape(last_saved_image)}
finally: pbar.leave = False pbar.close() - hypernetwork.eval_mode() + hypernetwork.eval() #report_statistics(loss_dict) filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index e7f9e593..81e3f519 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -9,8 +9,8 @@ from modules import devices, sd_hijack, shared not_available = ["hardswish", "multiheadattention"] keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) -def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): - filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout) +def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): + filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure) return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {filename}", "" diff --git a/modules/ui.py b/modules/ui.py index b6079aec..9b9081b5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1268,6 +1268,7 @@ def create_ui(): new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") + new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. 
ex:'0, 0.01, 0'") overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") with gr.Row(): @@ -1414,7 +1415,8 @@ def create_ui(): new_hypernetwork_activation_func, new_hypernetwork_initialization_option, new_hypernetwork_add_layer_norm, - new_hypernetwork_use_dropout + new_hypernetwork_use_dropout, + new_hypernetwork_dropout_structure ], outputs=[ train_hypernetwork_name, -- cgit v1.2.3 From e9f8292a3a6792b722696fcf8e32b3fcb43ba436 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Tue, 10 Jan 2023 11:54:48 +0300 Subject: Split history ui.py to ui_progress.py --- modules/ui.py | 1928 ------------------------------------------------ modules/ui_progress.py | 1928 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1928 insertions(+), 1928 deletions(-) delete mode 100644 modules/ui.py create mode 100644 modules/ui_progress.py diff --git a/modules/ui.py b/modules/ui.py deleted file mode 100644 index 9b9081b5..00000000 --- a/modules/ui.py +++ /dev/null @@ -1,1928 +0,0 @@ -import html -import json -import math -import mimetypes -import os -import platform -import random -import subprocess as sp -import sys -import tempfile -import time -import traceback -from functools import partial, reduce - -import gradio as gr -import gradio.routes -import gradio.utils -import numpy as np -from PIL import Image, PngImagePlugin -from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call - -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru -from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML -from modules.paths import script_path - -from modules.shared import opts, cmd_opts, restricted_opts - -import modules.codeformer_model -import modules.generation_parameters_copypaste as parameters_copypaste -import modules.gfpgan_model -import modules.hypernetworks.ui -import modules.scripts -import modules.shared as shared -import modules.styles -import modules.textual_inversion.ui -from modules import prompt_parser -from modules.images import save_image -from modules.sd_hijack import model_hijack -from modules.sd_samplers import samplers, samplers_for_img2img -from modules.textual_inversion import textual_inversion -import modules.hypernetworks.ui -from modules.generation_parameters_copypaste import image_from_url_text - -# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI -mimetypes.init() -mimetypes.add_type('application/javascript', '.js') - -if not cmd_opts.share and not cmd_opts.listen: - # fix gradio phoning home - gradio.utils.version_check = lambda: None - gradio.utils.get_local_ip_address = lambda: '127.0.0.1' - -if cmd_opts.ngrok is not None: - import modules.ngrok as ngrok - print('ngrok authtoken detected, trying to connect...') - ngrok.connect( - cmd_opts.ngrok, - cmd_opts.port if cmd_opts.port is not None else 7860, - cmd_opts.ngrok_region - ) - - -def gr_show(visible=True): - return {"visible": visible, "__type__": "update"} - - -sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" -sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None - -css_hide_progressbar = """ -.wrap .m-12 svg { display:none!important; } -.wrap .m-12::before { content:"Loading..." } -.wrap .z-20 svg { display:none!important; } -.wrap .z-20::before { content:"Loading..." 
} -.progress-bar { display:none!important; } -.meta-text { display:none!important; } -.meta-text-center { display:none!important; } -""" - -# Using constants for these since the variation selector isn't visible. -# Important that they exactly match script.js for tooltip to work. -random_symbol = '\U0001f3b2\ufe0f' # 🎲️ -reuse_symbol = '\u267b\ufe0f' # ♻️ -paste_symbol = '\u2199\ufe0f' # ↙ -folder_symbol = '\U0001f4c2' # 📂 -refresh_symbol = '\U0001f504' # 🔄 -save_style_symbol = '\U0001f4be' # 💾 -apply_style_symbol = '\U0001f4cb' # 📋 -clear_prompt_symbol = '\U0001F5D1' # 🗑️ - - -def plaintext_to_html(text): - text = "
<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>
" - return text - -def send_gradio_gallery_to_image(x): - if len(x) == 0: - return None - return image_from_url_text(x[0]) - -def save_files(js_data, images, do_make_zip, index): - import csv - filenames = [] - fullfns = [] - - #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it - class MyObject: - def __init__(self, d=None): - if d is not None: - for key, value in d.items(): - setattr(self, key, value) - - data = json.loads(js_data) - - p = MyObject(data) - path = opts.outdir_save - save_to_dirs = opts.use_save_to_dirs_for_ui - extension: str = opts.samples_format - start_index = 0 - - if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only - - images = [images[index]] - start_index = index - - os.makedirs(opts.outdir_save, exist_ok=True) - - with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: - at_start = file.tell() == 0 - writer = csv.writer(file) - if at_start: - writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) - - for image_index, filedata in enumerate(images, start_index): - image = image_from_url_text(filedata) - - is_grid = image_index < p.index_of_first_image - i = 0 if is_grid else (image_index - p.index_of_first_image) - - fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) - - filename = os.path.relpath(fullfn, path) - filenames.append(filename) - fullfns.append(fullfn) - if txt_fullfn: - filenames.append(os.path.basename(txt_fullfn)) - fullfns.append(txt_fullfn) - - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) - - # Make Zip - if do_make_zip: - zip_filepath = os.path.join(path, "images.zip") - - from zipfile import ZipFile - with ZipFile(zip_filepath, "w") as zip_file: - for i in range(len(fullfns)): - with open(fullfns[i], mode="rb") as f: - zip_file.writestr(filenames[i], f.read()) - fullfns.insert(0, zip_filepath) - - return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") - - -def calc_time_left(progress, threshold, label, force_display, show_eta): - if progress == 0: - return "" - else: - time_since_start = time.time() - shared.state.time_start - eta = (time_since_start/progress) - eta_relative = eta-time_since_start - if (eta_relative > threshold and show_eta) or force_display: - if eta_relative > 3600: - return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) - elif eta_relative > 60: - return label + time.strftime('%M:%S', time.gmtime(eta_relative)) - else: - return label + time.strftime('%Ss', time.gmtime(eta_relative)) - else: - return "" - - -def check_progress_call(id_part): - if shared.state.job_count == 0: - return "", gr_show(False), gr_show(False), gr_show(False) - - progress = 0 - - if shared.state.job_count > 0: - progress += shared.state.job_no / shared.state.job_count - if shared.state.sampling_steps > 0: - progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps - - # Show progress percentage and time left at the same moment, and base it also on steps done - show_eta = progress >= 0.01 or shared.state.sampling_step >= 10 - - 
time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta) - if time_left != "": - shared.state.time_left_force_display = True - - progress = min(progress, 1) - - progressbar = "" - if opts.show_progressbar: - progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}
""" - - image = gr_show(False) - preview_visibility = gr_show(False) - - if opts.show_progress_every_n_steps != 0: - shared.state.set_current_image() - image = shared.state.current_image - - if image is None: - image = gr.update(value=None) - else: - preview_visibility = gr_show(True) - - if shared.state.textinfo is not None: - textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) - else: - textinfo_result = gr_show(False) - - return f"
<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>
", preview_visibility, image, textinfo_result - - -def check_progress_call_initial(id_part): - shared.state.job_count = -1 - shared.state.current_latent = None - shared.state.current_image = None - shared.state.textinfo = None - shared.state.time_start = time.time() - shared.state.time_left_force_display = False - - return check_progress_call(id_part) - - -def visit(x, func, path=""): - if hasattr(x, 'children'): - for c in x.children: - visit(c, func, path) - elif x.label is not None: - func(path + "/" + str(x.label), x) - - -def add_style(name: str, prompt: str, negative_prompt: str): - if name is None: - return [gr_show() for x in range(4)] - - style = modules.styles.PromptStyle(name, prompt, negative_prompt) - shared.prompt_styles.styles[style.name] = style - # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we - # reserialize all styles every time we save them - shared.prompt_styles.save_styles(shared.styles_filename) - - return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] - - -def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): - from modules import processing, devices - - if not enable: - return "" - - p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) - - with devices.autocast(): - p.init([""], [0], [0]) - - return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" - - -def apply_styles(prompt, prompt_neg, style1_name, style2_name): - prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) - prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) - - return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] - - -def interrogate(image): - prompt = shared.interrogator.interrogate(image.convert("RGB")) - - return gr_show(True) if prompt is None else prompt - - -def interrogate_deepbooru(image): - prompt = deepbooru.model.tag(image) - return gr_show(True) if prompt is None else prompt - - -def create_seed_inputs(target_interface): - with FormRow(elem_id=target_interface + '_seed_row'): - seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') - seed.style(container=False) - random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') - reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') - - with gr.Group(elem_id=target_interface + '_subseed_show_box'): - seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) - - # Components to show/hide based on the 'Extra' checkbox - seed_extras = [] - - with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1: - seed_extras.append(seed_extra_row_1) - subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') - subseed.style(container=False) - random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') - reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') - subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + 
'_subseed_strength') - - with FormRow(visible=False) as seed_extra_row_2: - seed_extras.append(seed_extra_row_2) - seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w') - seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h') - - random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) - random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) - - def change_visibility(show): - return {comp: gr_show(show) for comp in seed_extras} - - seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras) - - return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox - - - -def connect_clear_prompt(button): - """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" - button.click( - _js="clear_prompt", - fn=None, - inputs=[], - outputs=[], - ) - - -def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): - """ Connects a 'reuse (sub)seed' button's click event so that it copies last used - (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength - was 0, i.e. no variation seed was used, it copies the normal seed value instead.""" - def copy_seed(gen_info_string: str, index): - res = -1 - - try: - gen_info = json.loads(gen_info_string) - index -= gen_info.get('index_of_first_image', 0) - - if is_subseed and gen_info.get('subseed_strength', 0) > 0: - all_subseeds = gen_info.get('all_subseeds', [-1]) - res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] - else: - all_seeds = gen_info.get('all_seeds', [-1]) - res = all_seeds[index if 0 <= index < len(all_seeds) else 0] - - except json.decoder.JSONDecodeError as e: - if gen_info_string != '': - print("Error parsing JSON generation info:", file=sys.stderr) - print(gen_info_string, file=sys.stderr) - - return [res, gr_show(False)] - - reuse_seed.click( - fn=copy_seed, - _js="(x, y) => [x, selected_gallery_index()]", - show_progress=False, - inputs=[generation_info, dummy_component], - outputs=[seed, dummy_component] - ) - - -def update_token_counter(text, steps): - try: - _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) - prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) - - except Exception: - # a parsing error can happen here during typing, and we don't want to bother the user with - # messages related to it in console - prompt_schedules = [[[steps, text]]] - - flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) - prompts = [prompt_text for step, prompt_text in flat_prompts] - token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) - style_class = ' class="red"' if (token_count > max_length) else "" - return f"{token_count}/{max_length}" - - -def create_toprow(is_img2img): - id_part = "img2img" if is_img2img else "txt2img" - - with gr.Row(elem_id="toprow"): - with gr.Column(scale=6): - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, - placeholder="Prompt (press Ctrl+Enter or 
Alt+Enter to generate)" - ) - - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, - placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" - ) - - with gr.Column(scale=1, elem_id="roll_col"): - paste = gr.Button(value=paste_symbol, elem_id="paste") - save_style = gr.Button(value=save_style_symbol, elem_id="style_create") - prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") - token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") - token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - - clear_prompt_button.click( - fn=lambda *x: x, - _js="confirm_clear_prompt", - inputs=[prompt, negative_prompt], - outputs=[prompt, negative_prompt], - ) - - button_interrogate = None - button_deepbooru = None - if is_img2img: - with gr.Column(scale=1, elem_id="interrogate_col"): - button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") - - with gr.Column(scale=1): - with gr.Row(): - skip = gr.Button('Skip', elem_id=f"{id_part}_skip") - interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") - submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') - - skip.click( - fn=lambda: shared.state.skip(), - inputs=[], - outputs=[], - ) - - interrupt.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - with gr.Row(): - with gr.Column(scale=1, elem_id="style_pos_col"): - prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - - with gr.Column(scale=1, elem_id="style_neg_col"): - prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - - return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button - - -def setup_progressbar(progressbar, preview, id_part, textinfo=None): - if textinfo is None: - textinfo = gr.HTML(visible=False) - - check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) - check_progress.click( - fn=lambda: check_progress_call(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) - check_progress_initial.click( - fn=lambda: check_progress_call_initial(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - -def apply_setting(key, value): - if value is None: - return gr.update() - - if shared.cmd_opts.freeze_settings: - return gr.update() - - # dont allow model to be swapped when model hash exists in prompt - if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: - return gr.update() - - if key == "sd_model_checkpoint": - ckpt_info = sd_models.get_closet_checkpoint_match(value) - - if ckpt_info is not None: - value = ckpt_info.title - else: - return gr.update() - - comp_args = 
opts.data_labels[key].component_args - if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: - return - - valtype = type(opts.data_labels[key].default) - oldval = opts.data.get(key, None) - opts.data[key] = valtype(value) if valtype != type(None) else value - if oldval != value and opts.data_labels[key].onchange is not None: - opts.data_labels[key].onchange() - - opts.save(shared.config_filename) - return value - - -def update_generation_info(args): - generation_info, html_info, img_index = args - try: - generation_info = json.loads(generation_info) - if img_index < 0 or img_index >= len(generation_info["infotexts"]): - return html_info - return plaintext_to_html(generation_info["infotexts"][img_index]) - except Exception: - pass - # if the json parse or anything else fails, just return the old html_info - return html_info - - -def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): - def refresh(): - refresh_method() - args = refreshed_args() if callable(refreshed_args) else refreshed_args - - for k, v in args.items(): - setattr(refresh_component, k, v) - - return gr.update(**(args or {})) - - refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) - refresh_button.click( - fn=refresh, - inputs=[], - outputs=[refresh_component] - ) - return refresh_button - - -def create_output_panel(tabname, outdir): - def open_folder(f): - if not os.path.exists(f): - print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') - return - elif not os.path.isdir(f): - print(f""" -WARNING -An open_folder request was made with an argument that is not a folder. -This could be an error or a malicious attempt to run code on your computer. -Requested path was: {f} -""", file=sys.stderr) - return - - if not shared.cmd_opts.hide_ui_dir_config: - path = os.path.normpath(f) - if platform.system() == "Windows": - os.startfile(path) - elif platform.system() == "Darwin": - sp.Popen(["open", path]) - elif "microsoft-standard-WSL2" in platform.uname().release: - sp.Popen(["wsl-open", path]) - else: - sp.Popen(["xdg-open", path]) - - with gr.Column(variant='panel'): - with gr.Group(): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) - - generation_info = None - with gr.Column(): - with gr.Row(elem_id=f"image_buttons_{tabname}"): - open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') - - if tabname != "extras": - save = gr.Button('Save', elem_id=f'save_{tabname}') - save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') - - buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) - - open_folder_button.click( - fn=lambda: open_folder(opts.outdir_samples or outdir), - inputs=[], - outputs=[], - ) - - if tabname != "extras": - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') - - with gr.Group(): - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') - if tabname == 'txt2img' or tabname == 'img2img': - generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") - generation_info_button.click( - fn=update_generation_info, - _js="(x, y) => [x, y, selected_gallery_index()]", - 
inputs=[generation_info, html_info], - outputs=[html_info], - preprocess=False - ) - - save.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - save_zip.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - else: - html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) - return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log - - -def create_sampler_and_steps_selection(choices, tabname): - if opts.samplers_in_dropdown: - with FormRow(elem_id=f"sampler_selection_{tabname}"): - sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - else: - with FormGroup(elem_id=f"sampler_selection_{tabname}"): - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") - - return steps, sampler_index - - -def ordered_ui_categories(): - user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))} - - for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)): - yield category - - -def create_ui(): - import modules.img2img - import modules.txt2img - - reload_javascript() - - parameters_copypaste.reset() - - modules.scripts.scripts_current = modules.scripts.scripts_txt2img - modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) - - with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) - - dummy_component = gr.Label(visible=False) - txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Row(elem_id='txt2img_progress_row'): - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="txt2img_progressbar") - txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) - setup_progressbar(progressbar, txt2img_preview, 'txt2img') - - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel', elem_id="txt2img_settings"): - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") - - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="txt2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") - height = gr.Slider(minimum=64, 
maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") - - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - - elif category == "cfg": - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") - - elif category == "seed": - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') - - elif category == "checkboxes": - with FormRow(elem_id="txt2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") - enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") - hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) - - elif category == "hires_fix": - with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: - with FormRow(elem_id="txt2img_hires_fix_row1"): - hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) - hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") - - with FormRow(elem_id="txt2img_hires_fix_row2"): - hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") - hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") - hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") - - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - - elif category == "scripts": - with FormGroup(elem_id="txt2img_script_container"): - custom_inputs = modules.scripts.scripts_txt2img.setup_ui() - - hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] - for input in hr_resolution_preview_inputs: - input.change( - fn=calc_resolution_hires, - inputs=hr_resolution_preview_inputs, - outputs=[hr_final_resolution], - show_progress=False, - ) - input.change( - None, - _js="onCalcResolutionHires", - inputs=hr_resolution_preview_inputs, - outputs=[], - show_progress=False, - ) - - txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) - parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, 
dummy_component, is_subseed=True) - - txt2img_args = dict( - fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), - _js="submit", - inputs=[ - txt2img_prompt, - txt2img_negative_prompt, - txt2img_prompt_style, - txt2img_prompt_style2, - steps, - sampler_index, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - enable_hr, - denoising_strength, - hr_scale, - hr_upscaler, - hr_second_pass_steps, - hr_resize_x, - hr_resize_y, - ] + custom_inputs, - - outputs=[ - txt2img_gallery, - generation_info, - html_info, - html_log, - ], - show_progress=False, - ) - - txt2img_prompt.submit(**txt2img_args) - submit.click(**txt2img_args) - - txt_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - txt_prompt_img - ], - outputs=[ - txt2img_prompt, - txt_prompt_img - ] - ) - - enable_hr.change( - fn=lambda x: gr_show(x), - inputs=[enable_hr], - outputs=[hr_options], - show_progress = False, - ) - - txt2img_paste_fields = [ - (txt2img_prompt, "Prompt"), - (txt2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - (enable_hr, lambda d: "Denoising strength" in d), - (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), - (hr_scale, "Hires upscale"), - (hr_upscaler, "Hires upscaler"), - (hr_second_pass_steps, "Hires steps"), - (hr_resize_x, "Hires resize-1"), - (hr_resize_y, "Hires resize-2"), - *modules.scripts.scripts_txt2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) - - txt2img_preview_params = [ - txt2img_prompt, - txt2img_negative_prompt, - steps, - sampler_index, - cfg_scale, - seed, - width, - height, - ] - - token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) - - modules.scripts.scripts_current = modules.scripts.scripts_img2img - modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) - - with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) - - with gr.Row(elem_id='img2img_progress_row'): - img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="img2img_progressbar") - img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) - setup_progressbar(progressbar, img2img_preview, 'img2img') - - with FormRow().style(equal_height=False): - with gr.Column(variant='panel', elem_id="img2img_settings"): - - with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: - with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", 
interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) - - with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) - init_img_with_mask_orig = gr.State(None) - - use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch" - if use_color_sketch: - def update_orig(image, state): - if image is not None: - same_size = state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state - - init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) - - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") - - with FormRow(): - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") - mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") - - with FormRow(): - mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") - inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - - with FormRow(): - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") - - with FormRow(): - with gr.Column(): - inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") - - with gr.Column(scale=4): - inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - - with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): - hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' - gr.HTML(f"
<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>
") - img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") - img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") - - with FormRow(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") - - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") - - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - elif category == "cfg": - with FormGroup(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") - - elif category == "seed": - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') - - elif category == "checkboxes": - with FormRow(elem_id="img2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") - - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - elif category == "scripts": - with FormGroup(elem_id="img2img_script_container"): - custom_inputs = modules.scripts.scripts_img2img.setup_ui() - - img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) - parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - img2img_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - img2img_prompt_img - ], - outputs=[ - img2img_prompt, - img2img_prompt_img - ] - ) - - mask_mode.change( - lambda mode, img: { - init_img_with_mask: gr_show(mode == 0), - init_img_inpaint: gr_show(mode == 1), - init_mask_inpaint: gr_show(mode == 1), - }, - inputs=[mask_mode, init_img_with_mask], - outputs=[ - init_img_with_mask, - init_img_inpaint, - init_mask_inpaint, - ], - ) - - img2img_args = dict( - fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), - _js="submit_img2img", - inputs=[ - 
dummy_component, - img2img_prompt, - img2img_negative_prompt, - img2img_prompt_style, - img2img_prompt_style2, - init_img, - init_img_with_mask, - init_img_with_mask_orig, - init_img_inpaint, - init_mask_inpaint, - mask_mode, - steps, - sampler_index, - mask_blur, - mask_alpha, - inpainting_fill, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - denoising_strength, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - resize_mode, - inpaint_full_res, - inpaint_full_res_padding, - inpainting_mask_invert, - img2img_batch_input_dir, - img2img_batch_output_dir, - ] + custom_inputs, - outputs=[ - img2img_gallery, - generation_info, - html_info, - html_log, - ], - show_progress=False, - ) - - img2img_prompt.submit(**img2img_args) - submit.click(**img2img_args) - - img2img_interrogate.click( - fn=interrogate, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - img2img_deepbooru.click( - fn=interrogate_deepbooru, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] - style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] - style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] - - for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): - button.click( - fn=add_style, - _js="ask_for_style_name", - # Have to pass empty dummy component here, because the JavaScript and Python function have to accept - # the same number of parameters, but we only know the style-name after the JavaScript prompt - inputs=[dummy_component, prompt, negative_prompt], - outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], - ) - - for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): - button.click( - fn=apply_styles, - _js=js_func, - inputs=[prompt, negative_prompt, style1, style2], - outputs=[prompt, negative_prompt, style1, style2], - ) - - token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) - - img2img_paste_fields = [ - (img2img_prompt, "Prompt"), - (img2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - (mask_blur, "Mask blur"), - *modules.scripts.scripts_img2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) - parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) - - modules.scripts.scripts_current = None - - with gr.Blocks(analytics_enabled=False) as extras_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - with gr.Tabs(elem_id="mode_extras"): - with gr.TabItem('Single Image', elem_id="extras_single_tab"): - extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") - - with gr.TabItem('Batch Process', 
elem_id="extras_batch_process_tab"): - image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") - - with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): - extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") - extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") - show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") - - submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - - with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") - with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): - with gr.Group(): - with gr.Row(): - upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") - upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") - - with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - - with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") - - with gr.Group(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") - - with gr.Group(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") - - with gr.Group(): - upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") - - result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) - - submit.click( - fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']), - _js="get_extras_tab_index", - inputs=[ - dummy_component, - dummy_component, - extras_image, - image_batch, - extras_batch_input_dir, - extras_batch_output_dir, - show_extras_results, - gfpgan_visibility, - codeformer_visibility, - codeformer_weight, - upscaling_resize, - upscaling_resize_w, - upscaling_resize_h, - upscaling_crop, - extras_upscaler_1, - extras_upscaler_2, - extras_upscaler_2_visibility, - upscale_before_face_fix, - ], - outputs=[ - result_images, - html_info_x, 
- html_info, - ] - ) - parameters_copypaste.add_paste_fields("extras", extras_image, None) - - extras_image.change( - fn=modules.extras.clear_cache, - inputs=[], outputs=[] - ) - - with gr.Blocks(analytics_enabled=False) as pnginfo_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") - - with gr.Column(variant='panel'): - html = gr.HTML() - generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info") - html2 = gr.HTML() - with gr.Row(): - buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) - parameters_copypaste.bind_buttons(buttons, image, generation_info) - - image.change( - fn=wrap_gradio_call(modules.extras.run_pnginfo), - inputs=[image], - outputs=[html, generation_info, html2], - ) - - with gr.Blocks(analytics_enabled=False) as modelmerger_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - gr.HTML(value="
<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your checkpoint directory.</p>
") - - with gr.Row(): - primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") - create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") - - secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") - create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B") - - tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") - create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") - - custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") - interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") - interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") - - with gr.Row(): - checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") - save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") - - modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') - - with gr.Column(variant='panel'): - submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) - - with gr.Blocks(analytics_enabled=False) as train_interface: - with gr.Row().style(equal_height=False): - gr.HTML(value="
<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>
") - - with gr.Row().style(equal_height=False): - with gr.Tabs(elem_id="train_tabs"): - - with gr.Tab(label="Create embedding"): - new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") - initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") - nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") - overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") - - with gr.Tab(label="Create hypernetwork"): - new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") - new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") - new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") - new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") - new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") - new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") - new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") - new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. 
ex:'0, 0.01, 0'") - overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") - - with gr.Tab(label="Preprocess images"): - process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") - process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") - process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") - process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") - - with gr.Row(): - process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") - process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") - process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") - process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") - - with gr.Row(visible=False) as process_split_extra_row: - process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") - process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") - - with gr.Row(visible=False) as process_focal_crop_row: - process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") - process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") - process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - with gr.Row(): - interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") - run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") - - process_split.change( - fn=lambda show: gr_show(show), - inputs=[process_split], - outputs=[process_split_extra_row], - ) - - process_focal_crop.change( - fn=lambda show: gr_show(show), - inputs=[process_focal_crop], - outputs=[process_focal_crop_row], - ) - - def get_textual_inversion_template_names(): - return sorted([x for x in textual_inversion.textual_inversion_templates]) - - with gr.Tab(label="Train"): - gr.HTML(value="
Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a>
") - with FormRow(): - train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) - create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") - - train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) - create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") - - with FormRow(): - embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") - hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") - - with FormRow(): - clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) - clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) - - with FormRow(): - batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") - gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") - - dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") - log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") - - with FormRow(): - template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) - create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") - - training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") - training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") - varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") - steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") - - with FormRow(): - create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") - save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") - - save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") - preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") - - shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") - tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") - - latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") - - with gr.Row(): - train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") - interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") - train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") - - params = script_callbacks.UiTrainTabParams(txt2img_preview_params) - - script_callbacks.ui_train_tabs_callback(params) - - with gr.Column(): - progressbar = gr.HTML(elem_id="ti_progressbar") - ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) - - ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) - ti_preview = gr.Image(elem_id='ti_preview', visible=False) - ti_progress = gr.HTML(elem_id="ti_progress", value="") - ti_outcome = gr.HTML(elem_id="ti_error", value="") - setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) - - create_embedding.click( - fn=modules.textual_inversion.ui.create_embedding, - inputs=[ - new_embedding_name, - initialization_text, - nvpt, - overwrite_old_embedding, - ], - outputs=[ - train_embedding_name, - ti_output, - ti_outcome, - ] - ) - - create_hypernetwork.click( - fn=modules.hypernetworks.ui.create_hypernetwork, - inputs=[ - new_hypernetwork_name, - new_hypernetwork_sizes, - overwrite_old_hypernetwork, - new_hypernetwork_layer_structure, - new_hypernetwork_activation_func, - new_hypernetwork_initialization_option, - new_hypernetwork_add_layer_norm, - new_hypernetwork_use_dropout, - new_hypernetwork_dropout_structure - ], - outputs=[ - train_hypernetwork_name, - ti_output, - ti_outcome, - ] - ) - - run_preprocess.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - process_src, - process_dst, - process_width, - process_height, - preprocess_txt_action, - process_flip, - process_split, - process_caption, - process_caption_deepbooru, - process_split_threshold, - process_overlap_ratio, - process_focal_crop, - process_focal_crop_face_weight, - process_focal_crop_entropy_weight, - process_focal_crop_edges_weight, - process_focal_crop_debug, - ], - outputs=[ - ti_output, - ti_outcome, - ], - ) - - train_embedding.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_embedding_name, - embedding_learn_rate, - batch_size, - gradient_step, - dataset_directory, - log_directory, - training_width, - training_height, - varsize, - steps, - clip_grad_mode, - clip_grad_value, - shuffle_tags, - tag_drop_out, - latent_sampling_method, - create_image_every, - save_embedding_every, - template_file, - save_image_with_stored_embedding, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - train_hypernetwork.click( - fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, 
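-        # extra_outputs supplies placeholder values for the leading outputs when the
-        # wrapped call raises; the formatted error message then fills the final output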
extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_hypernetwork_name, - hypernetwork_learn_rate, - batch_size, - gradient_step, - dataset_directory, - log_directory, - training_width, - training_height, - varsize, - steps, - clip_grad_mode, - clip_grad_value, - shuffle_tags, - tag_drop_out, - latent_sampling_method, - create_image_every, - save_embedding_every, - template_file, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - interrupt_training.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - interrupt_preprocessing.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - def create_setting_component(key, is_quicksettings=False): - def fun(): - return opts.data[key] if key in opts.data else opts.data_labels[key].default - - info = opts.data_labels[key] - t = type(info.default) - - args = info.component_args() if callable(info.component_args) else info.component_args - - if info.component is not None: - comp = info.component - elif t == str: - comp = gr.Textbox - elif t == int: - comp = gr.Number - elif t == bool: - comp = gr.Checkbox - else: - raise Exception(f'bad options item type: {str(t)} for key {key}') - - elem_id = "setting_"+key - - if info.refresh is not None: - if is_quicksettings: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - with FormRow(): - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - - return res - - components = [] - component_dict = {} - - script_callbacks.ui_settings_callback() - opts.reorder() - - def run_settings(*args): - changed = [] - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - if comp == dummy_component: - continue - - if opts.set(key, value): - changed.append(key) - - try: - opts.save(shared.config_filename) - except RuntimeError: - return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' - return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' 
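The run_settings handler above validates every incoming value against the type of its setting's default before applying any of them, and counts a setting as changed only when opts.set reports a real change. A minimal sketch of that contract, using a hypothetical MiniOptions stand-in rather than the real modules.shared.Options:

    class MiniOptions:
        def __init__(self, defaults):
            self.defaults = defaults
            self.data = dict(defaults)

        def same_type(self, value, default):
            # a value is acceptable only when it matches its default's type
            return type(value) == type(default)

        def set(self, key, value):
            # report True only when the stored value actually changes
            if self.data.get(key) == value:
                return False
            self.data[key] = value
            return True

    opts_demo = MiniOptions({"show_progressbar": True, "samples_format": "png"})
    assert opts_demo.same_type(False, opts_demo.defaults["show_progressbar"])
    assert opts_demo.set("samples_format", "jpg")       # changed -> counted
    assert not opts_demo.set("samples_format", "jpg")   # unchanged -> skipped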
- - def run_settings_single(value, key): - if not opts.same_type(value, opts.data_labels[key].default): - return gr.update(visible=True), opts.dumpjson() - - if not opts.set(key, value): - return gr.update(value=getattr(opts, key)), opts.dumpjson() - - opts.save(shared.config_filename) - - return gr.update(value=value), opts.dumpjson() - - with gr.Blocks(analytics_enabled=False) as settings_interface: - with gr.Row(): - with gr.Column(scale=6): - settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") - with gr.Column(): - restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") - - result = gr.HTML(elem_id="settings_result") - - quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] - quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} - - quicksettings_list = [] - - previous_section = None - current_tab = None - with gr.Tabs(elem_id="settings"): - for i, (k, item) in enumerate(opts.data_labels.items()): - section_must_be_skipped = item.section[0] is None - - if previous_section != item.section and not section_must_be_skipped: - elem_id, text = item.section - - if current_tab is not None: - current_tab.__exit__() - - current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) - current_tab.__enter__() - - previous_section = item.section - - if k in quicksettings_names and not shared.cmd_opts.freeze_settings: - quicksettings_list.append((i, k, item)) - components.append(dummy_component) - elif section_must_be_skipped: - components.append(dummy_component) - else: - component = create_setting_component(k) - component_dict[k] = component - components.append(component) - - if current_tab is not None: - current_tab.__exit__() - - with gr.TabItem("Actions"): - request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") - download_localization = gr.Button(value='Download localization template', elem_id="download_localization") - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") - - if os.path.exists("html/licenses.html"): - with open("html/licenses.html", encoding="utf8") as file: - with gr.TabItem("Licenses"): - gr.HTML(file.read(), elem_id="licenses") - - gr.Button(value="Show all pages", elem_id="settings_show_all_pages") - - request_notifications.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='function(){}' - ) - - download_localization.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='download_localization' - ) - - def reload_scripts(): - modules.scripts.reload_script_body_only() - reload_javascript() # need to refresh the html page - - reload_script_bodies.click( - fn=reload_scripts, - inputs=[], - outputs=[] - ) - - def request_restart(): - shared.state.interrupt() - shared.state.need_restart = True - - restart_gradio.click( - fn=request_restart, - _js='restart_reload', - inputs=[], - outputs=[], - ) - - interfaces = [ - (txt2img_interface, "txt2img", "txt2img"), - (img2img_interface, "img2img", "img2img"), - (extras_interface, "Extras", "extras"), - (pnginfo_interface, "PNG Info", "pnginfo"), - (modelmerger_interface, "Checkpoint Merger", "modelmerger"), - (train_interface, "Train", "ti"), - ] - - css = "" - - for cssfile in modules.scripts.list_files_with_name("style.css"): - if not os.path.isfile(cssfile): - continue - - with open(cssfile, "r", 
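-        # style.css files gathered from extensions are appended after the built-in
-        # stylesheet, so extension rules can override the default styling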
encoding="utf8") as file: - css += file.read() + "\n" - - if os.path.exists(os.path.join(script_path, "user.css")): - with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: - css += file.read() + "\n" - - if not cmd_opts.no_progressbar_hiding: - css += css_hide_progressbar - - interfaces += script_callbacks.ui_tabs_callback() - interfaces += [(settings_interface, "Settings", "settings")] - - extensions_interface = ui_extensions.create_ui() - interfaces += [(extensions_interface, "Extensions", "extensions")] - - with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: - with gr.Row(elem_id="quicksettings"): - for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): - component = create_setting_component(k, is_quicksettings=True) - component_dict[k] = component - - parameters_copypaste.integrate_settings_paste_fields(component_dict) - parameters_copypaste.run_bind() - - with gr.Tabs(elem_id="tabs") as tabs: - for interface, label, ifid in interfaces: - with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): - interface.render() - - if os.path.exists(os.path.join(script_path, "notification.mp3")): - audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) - - if os.path.exists("html/footer.html"): - with open("html/footer.html", encoding="utf8") as file: - footer = file.read() - footer = footer.format(versions=versions_html()) - gr.HTML(footer, elem_id="footer") - - text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) - settings_submit.click( - fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), - inputs=components, - outputs=[text_settings, result], - ) - - for i, k, item in quicksettings_list: - component = component_dict[k] - - component.change( - fn=lambda value, k=k: run_settings_single(value, key=k), - inputs=[component], - outputs=[component, text_settings], - ) - - component_keys = [k for k in opts.data_labels.keys() if k in component_dict] - - def get_settings_values(): - return [getattr(opts, key) for key in component_keys] - - demo.load( - fn=get_settings_values, - inputs=[], - outputs=[component_dict[k] for k in component_keys], - ) - - def modelmerger(*args): - try: - results = modules.extras.run_modelmerger(*args) - except Exception as e: - print("Error loading/saving model file:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - modules.sd_models.list_models() # to remove the potentially missing models from the list - return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)] - return results - - modelmerger_merge.click( - fn=modelmerger, - inputs=[ - primary_model_name, - secondary_model_name, - tertiary_model_name, - interp_method, - interp_amount, - save_as_half, - custom_name, - checkpoint_format, - ], - outputs=[ - submit_result, - primary_model_name, - secondary_model_name, - tertiary_model_name, - component_dict['sd_model_checkpoint'], - ] - ) - - ui_config_file = cmd_opts.ui_config_file - ui_settings = {} - settings_count = len(ui_settings) - error_loading = False - - try: - if os.path.exists(ui_config_file): - with open(ui_config_file, "r", encoding="utf8") as file: - ui_settings = json.load(file) - except Exception: - error_loading = True - print("Error loading settings:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - - def 
loadsave(path, x):
-        def apply_field(obj, field, condition=None, init_field=None):
-            key = path + "/" + field
-
-            if getattr(obj, 'custom_script_source', None) is not None:
-                key = 'customscript/' + obj.custom_script_source + '/' + key
-
-            if getattr(obj, 'do_not_save_to_config', False):
-                return
-
-            saved_value = ui_settings.get(key, None)
-            if saved_value is None:
-                ui_settings[key] = getattr(obj, field)
-            elif condition and not condition(saved_value):
-                print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
-            else:
-                setattr(obj, field, saved_value)
-                if init_field is not None:
-                    init_field(saved_value)
-
-        if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible:
-            apply_field(x, 'visible')
-
-        if type(x) == gr.Slider:
-            apply_field(x, 'value')
-            apply_field(x, 'minimum')
-            apply_field(x, 'maximum')
-            apply_field(x, 'step')
-
-        if type(x) == gr.Radio:
-            apply_field(x, 'value', lambda val: val in x.choices)
-
-        if type(x) == gr.Checkbox:
-            apply_field(x, 'value')
-
-        if type(x) == gr.Textbox:
-            apply_field(x, 'value')
-
-        if type(x) == gr.Number:
-            apply_field(x, 'value')
-
-        if type(x) == gr.Dropdown:
-            apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None))
-
-    visit(txt2img_interface, loadsave, "txt2img")
-    visit(img2img_interface, loadsave, "img2img")
-    visit(extras_interface, loadsave, "extras")
-    visit(modelmerger_interface, loadsave, "modelmerger")
-    visit(train_interface, loadsave, "train")
-
-    if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
-        with open(ui_config_file, "w", encoding="utf8") as file:
-            json.dump(ui_settings, file, indent=4)
-
-    return demo
-
-
-def reload_javascript():
-    with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
-        javascript = f'<script>{jsfile.read()}</script>'
-
-    scripts_list = modules.scripts.list_scripts("javascript", ".js")
-
-    for basedir, filename, path in scripts_list:
-        with open(path, "r", encoding="utf8") as jsfile:
-            javascript += f"\n<script>{jsfile.read()}</script>"
-
-    if cmd_opts.theme is not None:
-        javascript += f"\n<script>set_theme('{cmd_opts.theme}');</script>\n"
-
-    javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
-
-    def template_response(*args, **kwargs):
-        res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
-        res.body = res.body.replace(
-            b'</head>', f'{javascript}</head>'.encode("utf8"))
-        res.init_headers()
-        return res
-
-    gradio.routes.templates.TemplateResponse = template_response
-
-
-if not hasattr(shared, 'GradioTemplateResponseOriginal'):
-    shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse
-
-
-def versions_html():
-    import torch
-    import launch
-
-    python_version = ".".join([str(x) for x in sys.version_info[0:3]])
-    commit = launch.commit_hash()
-    short_commit = commit[0:8]
-
-    if shared.xformers_available:
-        import xformers
-        xformers_version = xformers.__version__
-    else:
-        xformers_version = "N/A"
-
-    return f"""
-python: <span title="{sys.version}">{python_version}</span>
- • 
-torch: {torch.__version__}
- • 
-xformers: {xformers_version}
- • 
-gradio: {gr.__version__}
- • 
-commit: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{short_commit}</a>
-"""
diff --git a/modules/ui_progress.py b/modules/ui_progress.py
new file mode 100644
index 00000000..9b9081b5
--- /dev/null
+++ b/modules/ui_progress.py
@@ -0,0 +1,1928 @@
+import html
+import json
+import math
+import mimetypes
+import os
+import platform
+import random
+import subprocess as sp
+import sys
+import tempfile
+import time
+import traceback
+from functools import partial, reduce
+
+import gradio as gr
+import gradio.routes
+import 
gradio.utils +import numpy as np +from PIL import Image, PngImagePlugin +from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call + +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru +from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML +from modules.paths import script_path + +from modules.shared import opts, cmd_opts, restricted_opts + +import modules.codeformer_model +import modules.generation_parameters_copypaste as parameters_copypaste +import modules.gfpgan_model +import modules.hypernetworks.ui +import modules.scripts +import modules.shared as shared +import modules.styles +import modules.textual_inversion.ui +from modules import prompt_parser +from modules.images import save_image +from modules.sd_hijack import model_hijack +from modules.sd_samplers import samplers, samplers_for_img2img +from modules.textual_inversion import textual_inversion +import modules.hypernetworks.ui +from modules.generation_parameters_copypaste import image_from_url_text + +# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI +mimetypes.init() +mimetypes.add_type('application/javascript', '.js') + +if not cmd_opts.share and not cmd_opts.listen: + # fix gradio phoning home + gradio.utils.version_check = lambda: None + gradio.utils.get_local_ip_address = lambda: '127.0.0.1' + +if cmd_opts.ngrok is not None: + import modules.ngrok as ngrok + print('ngrok authtoken detected, trying to connect...') + ngrok.connect( + cmd_opts.ngrok, + cmd_opts.port if cmd_opts.port is not None else 7860, + cmd_opts.ngrok_region + ) + + +def gr_show(visible=True): + return {"visible": visible, "__type__": "update"} + + +sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" +sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None + +css_hide_progressbar = """ +.wrap .m-12 svg { display:none!important; } +.wrap .m-12::before { content:"Loading..." } +.wrap .z-20 svg { display:none!important; } +.wrap .z-20::before { content:"Loading..." } +.progress-bar { display:none!important; } +.meta-text { display:none!important; } +.meta-text-center { display:none!important; } +""" + +# Using constants for these since the variation selector isn't visible. +# Important that they exactly match script.js for tooltip to work. +random_symbol = '\U0001f3b2\ufe0f' # 🎲️ +reuse_symbol = '\u267b\ufe0f' # ♻️ +paste_symbol = '\u2199\ufe0f' # ↙ +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +apply_style_symbol = '\U0001f4cb' # 📋 +clear_prompt_symbol = '\U0001F5D1' # 🗑️ + + +def plaintext_to_html(text): + text = "
<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>
" + return text + +def send_gradio_gallery_to_image(x): + if len(x) == 0: + return None + return image_from_url_text(x[0]) + +def save_files(js_data, images, do_make_zip, index): + import csv + filenames = [] + fullfns = [] + + #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it + class MyObject: + def __init__(self, d=None): + if d is not None: + for key, value in d.items(): + setattr(self, key, value) + + data = json.loads(js_data) + + p = MyObject(data) + path = opts.outdir_save + save_to_dirs = opts.use_save_to_dirs_for_ui + extension: str = opts.samples_format + start_index = 0 + + if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only + + images = [images[index]] + start_index = index + + os.makedirs(opts.outdir_save, exist_ok=True) + + with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: + at_start = file.tell() == 0 + writer = csv.writer(file) + if at_start: + writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) + + for image_index, filedata in enumerate(images, start_index): + image = image_from_url_text(filedata) + + is_grid = image_index < p.index_of_first_image + i = 0 if is_grid else (image_index - p.index_of_first_image) + + fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + + filename = os.path.relpath(fullfn, path) + filenames.append(filename) + fullfns.append(fullfn) + if txt_fullfn: + filenames.append(os.path.basename(txt_fullfn)) + fullfns.append(txt_fullfn) + + writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + + # Make Zip + if do_make_zip: + zip_filepath = os.path.join(path, "images.zip") + + from zipfile import ZipFile + with ZipFile(zip_filepath, "w") as zip_file: + for i in range(len(fullfns)): + with open(fullfns[i], mode="rb") as f: + zip_file.writestr(filenames[i], f.read()) + fullfns.insert(0, zip_filepath) + + return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") + + +def calc_time_left(progress, threshold, label, force_display, show_eta): + if progress == 0: + return "" + else: + time_since_start = time.time() - shared.state.time_start + eta = (time_since_start/progress) + eta_relative = eta-time_since_start + if (eta_relative > threshold and show_eta) or force_display: + if eta_relative > 3600: + return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) + elif eta_relative > 60: + return label + time.strftime('%M:%S', time.gmtime(eta_relative)) + else: + return label + time.strftime('%Ss', time.gmtime(eta_relative)) + else: + return "" + + +def check_progress_call(id_part): + if shared.state.job_count == 0: + return "", gr_show(False), gr_show(False), gr_show(False) + + progress = 0 + + if shared.state.job_count > 0: + progress += shared.state.job_no / shared.state.job_count + if shared.state.sampling_steps > 0: + progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps + + # Show progress percentage and time left at the same moment, and base it also on steps done + show_eta = progress >= 0.01 or shared.state.sampling_step >= 10 + + 
time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta) + if time_left != "": + shared.state.time_left_force_display = True + + progress = min(progress, 1) + + progressbar = "" + if opts.show_progressbar: + progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}
""" + + image = gr_show(False) + preview_visibility = gr_show(False) + + if opts.show_progress_every_n_steps != 0: + shared.state.set_current_image() + image = shared.state.current_image + + if image is None: + image = gr.update(value=None) + else: + preview_visibility = gr_show(True) + + if shared.state.textinfo is not None: + textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) + else: + textinfo_result = gr_show(False) + + return f"
<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>
", preview_visibility, image, textinfo_result + + +def check_progress_call_initial(id_part): + shared.state.job_count = -1 + shared.state.current_latent = None + shared.state.current_image = None + shared.state.textinfo = None + shared.state.time_start = time.time() + shared.state.time_left_force_display = False + + return check_progress_call(id_part) + + +def visit(x, func, path=""): + if hasattr(x, 'children'): + for c in x.children: + visit(c, func, path) + elif x.label is not None: + func(path + "/" + str(x.label), x) + + +def add_style(name: str, prompt: str, negative_prompt: str): + if name is None: + return [gr_show() for x in range(4)] + + style = modules.styles.PromptStyle(name, prompt, negative_prompt) + shared.prompt_styles.styles[style.name] = style + # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we + # reserialize all styles every time we save them + shared.prompt_styles.save_styles(shared.styles_filename) + + return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] + + +def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): + from modules import processing, devices + + if not enable: + return "" + + p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) + + with devices.autocast(): + p.init([""], [0], [0]) + + return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" + + +def apply_styles(prompt, prompt_neg, style1_name, style2_name): + prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) + prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) + + return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] + + +def interrogate(image): + prompt = shared.interrogator.interrogate(image.convert("RGB")) + + return gr_show(True) if prompt is None else prompt + + +def interrogate_deepbooru(image): + prompt = deepbooru.model.tag(image) + return gr_show(True) if prompt is None else prompt + + +def create_seed_inputs(target_interface): + with FormRow(elem_id=target_interface + '_seed_row'): + seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') + seed.style(container=False) + random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') + reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') + + with gr.Group(elem_id=target_interface + '_subseed_show_box'): + seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) + + # Components to show/hide based on the 'Extra' checkbox + seed_extras = [] + + with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1: + seed_extras.append(seed_extra_row_1) + subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') + subseed.style(container=False) + random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') + reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') + subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + 
'_subseed_strength')
+
+    with FormRow(visible=False) as seed_extra_row_2:
+        seed_extras.append(seed_extra_row_2)
+        seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
+        seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
+
+    random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
+    random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
+
+    def change_visibility(show):
+        return {comp: gr_show(show) for comp in seed_extras}
+
+    seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)
+
+    return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
+
+
+
+def connect_clear_prompt(button):
+    """Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
+    button.click(
+        _js="clear_prompt",
+        fn=None,
+        inputs=[],
+        outputs=[],
+    )
+
+
+def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
+    """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
+    (sub)seed value from generation info to the seed field. If copying subseed and subseed strength
+    was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
+    def copy_seed(gen_info_string: str, index):
+        res = -1
+
+        try:
+            gen_info = json.loads(gen_info_string)
+            index -= gen_info.get('index_of_first_image', 0)
+
+            if is_subseed and gen_info.get('subseed_strength', 0) > 0:
+                all_subseeds = gen_info.get('all_subseeds', [-1])
+                res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
+            else:
+                all_seeds = gen_info.get('all_seeds', [-1])
+                res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
+
+        except json.decoder.JSONDecodeError as e:
+            if gen_info_string != '':
+                print("Error parsing JSON generation info:", file=sys.stderr)
+                print(gen_info_string, file=sys.stderr)
+
+        return [res, gr_show(False)]
+
+    reuse_seed.click(
+        fn=copy_seed,
+        _js="(x, y) => [x, selected_gallery_index()]",
+        show_progress=False,
+        inputs=[generation_info, dummy_component],
+        outputs=[seed, dummy_component]
+    )
+
+
+def update_token_counter(text, steps):
+    try:
+        _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
+        prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
+
+    except Exception:
+        # a parsing error can happen here during typing, and we don't want to bother the user with
+        # messages related to it in console
+        prompt_schedules = [[[steps, text]]]
+
+    flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
+    prompts = [prompt_text for step, prompt_text in flat_prompts]
+    token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
+    style_class = ' class="red"' if (token_count > max_length) else ""
+    return f"<span {style_class}>{token_count}/{max_length}</span>"
+
+
+def create_toprow(is_img2img):
+    id_part = "img2img" if is_img2img else "txt2img"
+
+    with gr.Row(elem_id="toprow"):
+        with gr.Column(scale=6):
+            with gr.Row():
+                with gr.Column(scale=80):
+                    with gr.Row():
+                        prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
+                            placeholder="Prompt (press Ctrl+Enter or 
Alt+Enter to generate)" + ) + + with gr.Row(): + with gr.Column(scale=80): + with gr.Row(): + negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, + placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" + ) + + with gr.Column(scale=1, elem_id="roll_col"): + paste = gr.Button(value=paste_symbol, elem_id="paste") + save_style = gr.Button(value=save_style_symbol, elem_id="style_create") + prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") + clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") + token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") + + clear_prompt_button.click( + fn=lambda *x: x, + _js="confirm_clear_prompt", + inputs=[prompt, negative_prompt], + outputs=[prompt, negative_prompt], + ) + + button_interrogate = None + button_deepbooru = None + if is_img2img: + with gr.Column(scale=1, elem_id="interrogate_col"): + button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") + button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + + with gr.Column(scale=1): + with gr.Row(): + skip = gr.Button('Skip', elem_id=f"{id_part}_skip") + interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") + submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') + + skip.click( + fn=lambda: shared.state.skip(), + inputs=[], + outputs=[], + ) + + interrupt.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + with gr.Row(): + with gr.Column(scale=1, elem_id="style_pos_col"): + prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + + with gr.Column(scale=1, elem_id="style_neg_col"): + prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + + return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button + + +def setup_progressbar(progressbar, preview, id_part, textinfo=None): + if textinfo is None: + textinfo = gr.HTML(visible=False) + + check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) + check_progress.click( + fn=lambda: check_progress_call(id_part), + show_progress=False, + inputs=[], + outputs=[progressbar, preview, preview, textinfo], + ) + + check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) + check_progress_initial.click( + fn=lambda: check_progress_call_initial(id_part), + show_progress=False, + inputs=[], + outputs=[progressbar, preview, preview, textinfo], + ) + + +def apply_setting(key, value): + if value is None: + return gr.update() + + if shared.cmd_opts.freeze_settings: + return gr.update() + + # dont allow model to be swapped when model hash exists in prompt + if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: + return gr.update() + + if key == "sd_model_checkpoint": + ckpt_info = sd_models.get_closet_checkpoint_match(value) + + if ckpt_info is not None: + value = ckpt_info.title + else: + return gr.update() + + comp_args = 
opts.data_labels[key].component_args + if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: + return + + valtype = type(opts.data_labels[key].default) + oldval = opts.data.get(key, None) + opts.data[key] = valtype(value) if valtype != type(None) else value + if oldval != value and opts.data_labels[key].onchange is not None: + opts.data_labels[key].onchange() + + opts.save(shared.config_filename) + return value + + +def update_generation_info(args): + generation_info, html_info, img_index = args + try: + generation_info = json.loads(generation_info) + if img_index < 0 or img_index >= len(generation_info["infotexts"]): + return html_info + return plaintext_to_html(generation_info["infotexts"][img_index]) + except Exception: + pass + # if the json parse or anything else fails, just return the old html_info + return html_info + + +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + + refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + return refresh_button + + +def create_output_panel(tabname, outdir): + def open_folder(f): + if not os.path.exists(f): + print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') + return + elif not os.path.isdir(f): + print(f""" +WARNING +An open_folder request was made with an argument that is not a folder. +This could be an error or a malicious attempt to run code on your computer. +Requested path was: {f} +""", file=sys.stderr) + return + + if not shared.cmd_opts.hide_ui_dir_config: + path = os.path.normpath(f) + if platform.system() == "Windows": + os.startfile(path) + elif platform.system() == "Darwin": + sp.Popen(["open", path]) + elif "microsoft-standard-WSL2" in platform.uname().release: + sp.Popen(["wsl-open", path]) + else: + sp.Popen(["xdg-open", path]) + + with gr.Column(variant='panel'): + with gr.Group(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) + + generation_info = None + with gr.Column(): + with gr.Row(elem_id=f"image_buttons_{tabname}"): + open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') + + if tabname != "extras": + save = gr.Button('Save', elem_id=f'save_{tabname}') + save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') + + buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) + + open_folder_button.click( + fn=lambda: open_folder(opts.outdir_samples or outdir), + inputs=[], + outputs=[], + ) + + if tabname != "extras": + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') + + with gr.Group(): + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') + if tabname == 'txt2img' or tabname == 'img2img': + generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") + generation_info_button.click( + fn=update_generation_info, + _js="(x, y) => [x, y, selected_gallery_index()]", + 
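+                # the _js shim appends selected_gallery_index() as a third value, which
+                # update_generation_info unpacks as img_index (preprocess=False passes
+                # the raw argument list straight through to the function)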
inputs=[generation_info, html_info], + outputs=[html_info], + preprocess=False + ) + + save.click( + fn=wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ] + ) + + save_zip.click( + fn=wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ] + ) + + else: + html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) + return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log + + +def create_sampler_and_steps_selection(choices, tabname): + if opts.samplers_in_dropdown: + with FormRow(elem_id=f"sampler_selection_{tabname}"): + sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) + else: + with FormGroup(elem_id=f"sampler_selection_{tabname}"): + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) + sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + + return steps, sampler_index + + +def ordered_ui_categories(): + user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))} + + for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)): + yield category + + +def create_ui(): + import modules.img2img + import modules.txt2img + + reload_javascript() + + parameters_copypaste.reset() + + modules.scripts.scripts_current = modules.scripts.scripts_txt2img + modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) + + with gr.Blocks(analytics_enabled=False) as txt2img_interface: + txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) + + dummy_component = gr.Label(visible=False) + txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) + + with gr.Row(elem_id='txt2img_progress_row'): + with gr.Column(scale=1): + pass + + with gr.Column(scale=1): + progressbar = gr.HTML(elem_id="txt2img_progressbar") + txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) + setup_progressbar(progressbar, txt2img_preview, 'txt2img') + + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel', elem_id="txt2img_settings"): + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") + + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="txt2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") + height = gr.Slider(minimum=64, 
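+                            # sliders step by 8 because the VAE downsamples by a factor
+                            # of 8, so pixel dimensions must stay multiples of 8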
maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") + + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "cfg": + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') + + elif category == "checkboxes": + with FormRow(elem_id="txt2img_checkboxes"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") + enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") + hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) + + elif category == "hires_fix": + with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: + with FormRow(elem_id="txt2img_hires_fix_row1"): + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") + + with FormRow(elem_id="txt2img_hires_fix_row2"): + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") + hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") + hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") + + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "scripts": + with FormGroup(elem_id="txt2img_script_container"): + custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + + hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] + for input in hr_resolution_preview_inputs: + input.change( + fn=calc_resolution_hires, + inputs=hr_resolution_preview_inputs, + outputs=[hr_final_resolution], + show_progress=False, + ) + input.change( + None, + _js="onCalcResolutionHires", + inputs=hr_resolution_preview_inputs, + outputs=[], + show_progress=False, + ) + + txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) + parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, 
dummy_component, is_subseed=True) + + txt2img_args = dict( + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), + _js="submit", + inputs=[ + txt2img_prompt, + txt2img_negative_prompt, + txt2img_prompt_style, + txt2img_prompt_style2, + steps, + sampler_index, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + enable_hr, + denoising_strength, + hr_scale, + hr_upscaler, + hr_second_pass_steps, + hr_resize_x, + hr_resize_y, + ] + custom_inputs, + + outputs=[ + txt2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + + txt2img_prompt.submit(**txt2img_args) + submit.click(**txt2img_args) + + txt_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + txt_prompt_img + ], + outputs=[ + txt2img_prompt, + txt_prompt_img + ] + ) + + enable_hr.change( + fn=lambda x: gr_show(x), + inputs=[enable_hr], + outputs=[hr_options], + show_progress = False, + ) + + txt2img_paste_fields = [ + (txt2img_prompt, "Prompt"), + (txt2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (denoising_strength, "Denoising strength"), + (enable_hr, lambda d: "Denoising strength" in d), + (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), + (hr_scale, "Hires upscale"), + (hr_upscaler, "Hires upscaler"), + (hr_second_pass_steps, "Hires steps"), + (hr_resize_x, "Hires resize-1"), + (hr_resize_y, "Hires resize-2"), + *modules.scripts.scripts_txt2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) + + txt2img_preview_params = [ + txt2img_prompt, + txt2img_negative_prompt, + steps, + sampler_index, + cfg_scale, + seed, + width, + height, + ] + + token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) + + modules.scripts.scripts_current = modules.scripts.scripts_img2img + modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) + + with gr.Blocks(analytics_enabled=False) as img2img_interface: + img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) + + with gr.Row(elem_id='img2img_progress_row'): + img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) + + with gr.Column(scale=1): + pass + + with gr.Column(scale=1): + progressbar = gr.HTML(elem_id="img2img_progressbar") + img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) + setup_progressbar(progressbar, img2img_preview, 'img2img') + + with FormRow().style(equal_height=False): + with gr.Column(variant='panel', elem_id="img2img_settings"): + + with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: + with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", 
interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) + + with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) + init_img_with_mask_orig = gr.State(None) + + use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch" + if use_color_sketch: + def update_orig(image, state): + if image is not None: + same_size = state is not None and state.size == image.size + has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + edited = same_size and has_exact_match + return image if not edited or state is None else state + + init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) + + init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") + + with FormRow(): + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") + mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") + + with FormRow(): + mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") + inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") + + with FormRow(): + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") + + with FormRow(): + with gr.Column(): + inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + + with gr.Column(scale=4): + inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") + + with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): + hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' + gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>
") + img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") + img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + + with FormRow(): + resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") + + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="img2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") + + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + elif category == "cfg": + with FormGroup(): + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') + + elif category == "checkboxes": + with FormRow(elem_id="img2img_checkboxes"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") + + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + elif category == "scripts": + with FormGroup(elem_id="img2img_script_container"): + custom_inputs = modules.scripts.scripts_img2img.setup_ui() + + img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) + parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + + img2img_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + img2img_prompt_img + ], + outputs=[ + img2img_prompt, + img2img_prompt_img + ] + ) + + mask_mode.change( + lambda mode, img: { + init_img_with_mask: gr_show(mode == 0), + init_img_inpaint: gr_show(mode == 1), + init_mask_inpaint: gr_show(mode == 1), + }, + inputs=[mask_mode, init_img_with_mask], + outputs=[ + init_img_with_mask, + init_img_inpaint, + init_mask_inpaint, + ], + ) + + img2img_args = dict( + fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), + _js="submit_img2img", + inputs=[ + 
dummy_component, + img2img_prompt, + img2img_negative_prompt, + img2img_prompt_style, + img2img_prompt_style2, + init_img, + init_img_with_mask, + init_img_with_mask_orig, + init_img_inpaint, + init_mask_inpaint, + mask_mode, + steps, + sampler_index, + mask_blur, + mask_alpha, + inpainting_fill, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + denoising_strength, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + resize_mode, + inpaint_full_res, + inpaint_full_res_padding, + inpainting_mask_invert, + img2img_batch_input_dir, + img2img_batch_output_dir, + ] + custom_inputs, + outputs=[ + img2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + + img2img_prompt.submit(**img2img_args) + submit.click(**img2img_args) + + img2img_interrogate.click( + fn=interrogate, + inputs=[init_img], + outputs=[img2img_prompt], + ) + + img2img_deepbooru.click( + fn=interrogate_deepbooru, + inputs=[init_img], + outputs=[img2img_prompt], + ) + + prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] + style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] + style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] + + for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): + button.click( + fn=add_style, + _js="ask_for_style_name", + # Have to pass empty dummy component here, because the JavaScript and Python function have to accept + # the same number of parameters, but we only know the style-name after the JavaScript prompt + inputs=[dummy_component, prompt, negative_prompt], + outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], + ) + + for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): + button.click( + fn=apply_styles, + _js=js_func, + inputs=[prompt, negative_prompt, style1, style2], + outputs=[prompt, negative_prompt, style1, style2], + ) + + token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) + + img2img_paste_fields = [ + (img2img_prompt, "Prompt"), + (img2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (denoising_strength, "Denoising strength"), + (mask_blur, "Mask blur"), + *modules.scripts.scripts_img2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) + parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) + + modules.scripts.scripts_current = None + + with gr.Blocks(analytics_enabled=False) as extras_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + with gr.Tabs(elem_id="mode_extras"): + with gr.TabItem('Single Image', elem_id="extras_single_tab"): + extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") + + with gr.TabItem('Batch Process', 
elem_id="extras_batch_process_tab"): + image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") + + with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): + extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") + extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") + show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") + + submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') + + with gr.Tabs(elem_id="extras_resize_mode"): + with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") + with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): + with gr.Group(): + with gr.Row(): + upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") + + with gr.Group(): + extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + + with gr.Group(): + extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") + + with gr.Group(): + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") + + with gr.Group(): + codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") + codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") + + with gr.Group(): + upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") + + result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) + + submit.click( + fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']), + _js="get_extras_tab_index", + inputs=[ + dummy_component, + dummy_component, + extras_image, + image_batch, + extras_batch_input_dir, + extras_batch_output_dir, + show_extras_results, + gfpgan_visibility, + codeformer_visibility, + codeformer_weight, + upscaling_resize, + upscaling_resize_w, + upscaling_resize_h, + upscaling_crop, + extras_upscaler_1, + extras_upscaler_2, + extras_upscaler_2_visibility, + upscale_before_face_fix, + ], + outputs=[ + result_images, + html_info_x, 
+ html_info, + ] + ) + parameters_copypaste.add_paste_fields("extras", extras_image, None) + + extras_image.change( + fn=modules.extras.clear_cache, + inputs=[], outputs=[] + ) + + with gr.Blocks(analytics_enabled=False) as pnginfo_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") + + with gr.Column(variant='panel'): + html = gr.HTML() + generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info") + html2 = gr.HTML() + with gr.Row(): + buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) + parameters_copypaste.bind_buttons(buttons, image, generation_info) + + image.change( + fn=wrap_gradio_call(modules.extras.run_pnginfo), + inputs=[image], + outputs=[html, generation_info, html2], + ) + + with gr.Blocks(analytics_enabled=False) as modelmerger_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + gr.HTML(value="

<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your checkpoint directory.</p>
") + + with gr.Row(): + primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") + create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") + + secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") + create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B") + + tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") + create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") + + custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") + interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") + interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") + + with gr.Row(): + checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") + save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") + + modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') + + with gr.Column(variant='panel'): + submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) + + with gr.Blocks(analytics_enabled=False) as train_interface: + with gr.Row().style(equal_height=False): + gr.HTML(value="

<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>
") + + with gr.Row().style(equal_height=False): + with gr.Tabs(elem_id="train_tabs"): + + with gr.Tab(label="Create embedding"): + new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") + initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") + nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") + overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") + + with gr.Tab(label="Create hypernetwork"): + new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") + new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") + new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") + new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") + new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") + new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") + new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") + new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. 
ex:'0, 0.01, 0'") + overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") + + with gr.Tab(label="Preprocess images"): + process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") + process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") + process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") + process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") + preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") + + with gr.Row(): + process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") + process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") + process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") + process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") + process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") + + with gr.Row(visible=False) as process_split_extra_row: + process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") + process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") + + with gr.Row(visible=False) as process_focal_crop_row: + process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") + process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") + process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") + process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + with gr.Row(): + interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") + run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") + + process_split.change( + fn=lambda show: gr_show(show), + inputs=[process_split], + outputs=[process_split_extra_row], + ) + + process_focal_crop.change( + fn=lambda show: gr_show(show), + inputs=[process_focal_crop], + outputs=[process_focal_crop_row], + ) + + def get_textual_inversion_template_names(): + return sorted([x for x in textual_inversion.textual_inversion_templates]) + + with gr.Tab(label="Train"): + gr.HTML(value="

<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>
") + with FormRow(): + train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") + + train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) + create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") + + with FormRow(): + embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") + hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") + + with FormRow(): + clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) + clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) + + with FormRow(): + batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") + gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") + + dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") + log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") + + with FormRow(): + template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) + create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") + + training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") + training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") + varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") + steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") + + with FormRow(): + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") + + save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") + preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") + + shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") + tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") + + latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") + + with gr.Row(): + train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") + interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") + train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") + + params = script_callbacks.UiTrainTabParams(txt2img_preview_params) + + script_callbacks.ui_train_tabs_callback(params) + + with gr.Column(): + progressbar = gr.HTML(elem_id="ti_progressbar") + ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) + + ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) + ti_preview = gr.Image(elem_id='ti_preview', visible=False) + ti_progress = gr.HTML(elem_id="ti_progress", value="") + ti_outcome = gr.HTML(elem_id="ti_error", value="") + setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) + + create_embedding.click( + fn=modules.textual_inversion.ui.create_embedding, + inputs=[ + new_embedding_name, + initialization_text, + nvpt, + overwrite_old_embedding, + ], + outputs=[ + train_embedding_name, + ti_output, + ti_outcome, + ] + ) + + create_hypernetwork.click( + fn=modules.hypernetworks.ui.create_hypernetwork, + inputs=[ + new_hypernetwork_name, + new_hypernetwork_sizes, + overwrite_old_hypernetwork, + new_hypernetwork_layer_structure, + new_hypernetwork_activation_func, + new_hypernetwork_initialization_option, + new_hypernetwork_add_layer_norm, + new_hypernetwork_use_dropout, + new_hypernetwork_dropout_structure + ], + outputs=[ + train_hypernetwork_name, + ti_output, + ti_outcome, + ] + ) + + run_preprocess.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + process_src, + process_dst, + process_width, + process_height, + preprocess_txt_action, + process_flip, + process_split, + process_caption, + process_caption_deepbooru, + process_split_threshold, + process_overlap_ratio, + process_focal_crop, + process_focal_crop_face_weight, + process_focal_crop_entropy_weight, + process_focal_crop_edges_weight, + process_focal_crop_debug, + ], + outputs=[ + ti_output, + ti_outcome, + ], + ) + + train_embedding.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_embedding_name, + embedding_learn_rate, + batch_size, + gradient_step, + dataset_directory, + log_directory, + training_width, + training_height, + varsize, + steps, + clip_grad_mode, + clip_grad_value, + shuffle_tags, + tag_drop_out, + latent_sampling_method, + create_image_every, + save_embedding_every, + template_file, + save_image_with_stored_embedding, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + train_hypernetwork.click( + fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, 
extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_hypernetwork_name, + hypernetwork_learn_rate, + batch_size, + gradient_step, + dataset_directory, + log_directory, + training_width, + training_height, + varsize, + steps, + clip_grad_mode, + clip_grad_value, + shuffle_tags, + tag_drop_out, + latent_sampling_method, + create_image_every, + save_embedding_every, + template_file, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + interrupt_training.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + interrupt_preprocessing.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + def create_setting_component(key, is_quicksettings=False): + def fun(): + return opts.data[key] if key in opts.data else opts.data_labels[key].default + + info = opts.data_labels[key] + t = type(info.default) + + args = info.component_args() if callable(info.component_args) else info.component_args + + if info.component is not None: + comp = info.component + elif t == str: + comp = gr.Textbox + elif t == int: + comp = gr.Number + elif t == bool: + comp = gr.Checkbox + else: + raise Exception(f'bad options item type: {str(t)} for key {key}') + + elem_id = "setting_"+key + + if info.refresh is not None: + if is_quicksettings: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + with FormRow(): + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + + return res + + components = [] + component_dict = {} + + script_callbacks.ui_settings_callback() + opts.reorder() + + def run_settings(*args): + changed = [] + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + if comp == dummy_component: + continue + + if opts.set(key, value): + changed.append(key) + + try: + opts.save(shared.config_filename) + except RuntimeError: + return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' 
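Editor's aside, an illustrative sketch rather than part of the patch: run_settings above makes two passes over the incoming values, first asserting that each value matches the type of its setting's default, and only then mutating and persisting, so a single bad value rejects the whole batch before anything changes. The Options class below is a hypothetical stand-in for the webui's opts object, written only to show that validate-then-apply shape:

    # Hypothetical stand-in for modules.shared.opts; illustration only.
    class Options:
        def __init__(self, defaults):
            self.defaults = defaults        # key -> default value; the default fixes each key's type
            self.data = dict(defaults)

        def same_type(self, a, b):
            return type(a) == type(b)

        def set(self, key, value):
            changed = self.data.get(key) != value
            self.data[key] = value
            return changed                  # True only when the stored value actually changed

    def apply_settings(opts, incoming):
        # pass 1: type-check everything, so a bad value aborts before any mutation
        for key, value in incoming.items():
            assert opts.same_type(value, opts.defaults[key]), f"Bad value for setting {key}: {value!r}"
        # pass 2: apply, keeping only the keys whose stored values really changed
        return [key for key, value in incoming.items() if opts.set(key, value)]

    opts = Options({"samples_format": "png", "grid_save": True})
    print(apply_settings(opts, {"samples_format": "jpg", "grid_save": True}))  # ['samples_format']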
+ + def run_settings_single(value, key): + if not opts.same_type(value, opts.data_labels[key].default): + return gr.update(visible=True), opts.dumpjson() + + if not opts.set(key, value): + return gr.update(value=getattr(opts, key)), opts.dumpjson() + + opts.save(shared.config_filename) + + return gr.update(value=value), opts.dumpjson() + + with gr.Blocks(analytics_enabled=False) as settings_interface: + with gr.Row(): + with gr.Column(scale=6): + settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") + with gr.Column(): + restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") + + result = gr.HTML(elem_id="settings_result") + + quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] + quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} + + quicksettings_list = [] + + previous_section = None + current_tab = None + with gr.Tabs(elem_id="settings"): + for i, (k, item) in enumerate(opts.data_labels.items()): + section_must_be_skipped = item.section[0] is None + + if previous_section != item.section and not section_must_be_skipped: + elem_id, text = item.section + + if current_tab is not None: + current_tab.__exit__() + + current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) + current_tab.__enter__() + + previous_section = item.section + + if k in quicksettings_names and not shared.cmd_opts.freeze_settings: + quicksettings_list.append((i, k, item)) + components.append(dummy_component) + elif section_must_be_skipped: + components.append(dummy_component) + else: + component = create_setting_component(k) + component_dict[k] = component + components.append(component) + + if current_tab is not None: + current_tab.__exit__() + + with gr.TabItem("Actions"): + request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") + + if os.path.exists("html/licenses.html"): + with open("html/licenses.html", encoding="utf8") as file: + with gr.TabItem("Licenses"): + gr.HTML(file.read(), elem_id="licenses") + + gr.Button(value="Show all pages", elem_id="settings_show_all_pages") + + request_notifications.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='function(){}' + ) + + download_localization.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='download_localization' + ) + + def reload_scripts(): + modules.scripts.reload_script_body_only() + reload_javascript() # need to refresh the html page + + reload_script_bodies.click( + fn=reload_scripts, + inputs=[], + outputs=[] + ) + + def request_restart(): + shared.state.interrupt() + shared.state.need_restart = True + + restart_gradio.click( + fn=request_restart, + _js='restart_reload', + inputs=[], + outputs=[], + ) + + interfaces = [ + (txt2img_interface, "txt2img", "txt2img"), + (img2img_interface, "img2img", "img2img"), + (extras_interface, "Extras", "extras"), + (pnginfo_interface, "PNG Info", "pnginfo"), + (modelmerger_interface, "Checkpoint Merger", "modelmerger"), + (train_interface, "Train", "ti"), + ] + + css = "" + + for cssfile in modules.scripts.list_files_with_name("style.css"): + if not os.path.isfile(cssfile): + continue + + with open(cssfile, "r", 
encoding="utf8") as file: + css += file.read() + "\n" + + if os.path.exists(os.path.join(script_path, "user.css")): + with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: + css += file.read() + "\n" + + if not cmd_opts.no_progressbar_hiding: + css += css_hide_progressbar + + interfaces += script_callbacks.ui_tabs_callback() + interfaces += [(settings_interface, "Settings", "settings")] + + extensions_interface = ui_extensions.create_ui() + interfaces += [(extensions_interface, "Extensions", "extensions")] + + with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: + with gr.Row(elem_id="quicksettings"): + for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): + component = create_setting_component(k, is_quicksettings=True) + component_dict[k] = component + + parameters_copypaste.integrate_settings_paste_fields(component_dict) + parameters_copypaste.run_bind() + + with gr.Tabs(elem_id="tabs") as tabs: + for interface, label, ifid in interfaces: + with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): + interface.render() + + if os.path.exists(os.path.join(script_path, "notification.mp3")): + audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) + + if os.path.exists("html/footer.html"): + with open("html/footer.html", encoding="utf8") as file: + footer = file.read() + footer = footer.format(versions=versions_html()) + gr.HTML(footer, elem_id="footer") + + text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) + settings_submit.click( + fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), + inputs=components, + outputs=[text_settings, result], + ) + + for i, k, item in quicksettings_list: + component = component_dict[k] + + component.change( + fn=lambda value, k=k: run_settings_single(value, key=k), + inputs=[component], + outputs=[component, text_settings], + ) + + component_keys = [k for k in opts.data_labels.keys() if k in component_dict] + + def get_settings_values(): + return [getattr(opts, key) for key in component_keys] + + demo.load( + fn=get_settings_values, + inputs=[], + outputs=[component_dict[k] for k in component_keys], + ) + + def modelmerger(*args): + try: + results = modules.extras.run_modelmerger(*args) + except Exception as e: + print("Error loading/saving model file:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + modules.sd_models.list_models() # to remove the potentially missing models from the list + return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)] + return results + + modelmerger_merge.click( + fn=modelmerger, + inputs=[ + primary_model_name, + secondary_model_name, + tertiary_model_name, + interp_method, + interp_amount, + save_as_half, + custom_name, + checkpoint_format, + ], + outputs=[ + submit_result, + primary_model_name, + secondary_model_name, + tertiary_model_name, + component_dict['sd_model_checkpoint'], + ] + ) + + ui_config_file = cmd_opts.ui_config_file + ui_settings = {} + settings_count = len(ui_settings) + error_loading = False + + try: + if os.path.exists(ui_config_file): + with open(ui_config_file, "r", encoding="utf8") as file: + ui_settings = json.load(file) + except Exception: + error_loading = True + print("Error loading settings:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + def 
loadsave(path, x): + def apply_field(obj, field, condition=None, init_field=None): + key = path + "/" + field + + if getattr(obj, 'custom_script_source', None) is not None: + key = 'customscript/' + obj.custom_script_source + '/' + key + + if getattr(obj, 'do_not_save_to_config', False): + return + + saved_value = ui_settings.get(key, None) + if saved_value is None: + ui_settings[key] = getattr(obj, field) + elif condition and not condition(saved_value): + print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') + else: + setattr(obj, field, saved_value) + if init_field is not None: + init_field(saved_value) + + if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible: + apply_field(x, 'visible') + + if type(x) == gr.Slider: + apply_field(x, 'value') + apply_field(x, 'minimum') + apply_field(x, 'maximum') + apply_field(x, 'step') + + if type(x) == gr.Radio: + apply_field(x, 'value', lambda val: val in x.choices) + + if type(x) == gr.Checkbox: + apply_field(x, 'value') + + if type(x) == gr.Textbox: + apply_field(x, 'value') + + if type(x) == gr.Number: + apply_field(x, 'value') + + if type(x) == gr.Dropdown: + apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) + + visit(txt2img_interface, loadsave, "txt2img") + visit(img2img_interface, loadsave, "img2img") + visit(extras_interface, loadsave, "extras") + visit(modelmerger_interface, loadsave, "modelmerger") + visit(train_interface, loadsave, "train") + + if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): + with open(ui_config_file, "w", encoding="utf8") as file: + json.dump(ui_settings, file, indent=4) + + return demo + + +def reload_javascript(): + with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: + javascript = f'<script>{jsfile.read()}</script>' + + scripts_list = modules.scripts.list_scripts("javascript", ".js") + + for basedir, filename, path in scripts_list: + with open(path, "r", encoding="utf8") as jsfile: + javascript += f"\n<script>{jsfile.read()}</script>" + + if cmd_opts.theme is not None: + javascript += f"\n<script>set_theme('{cmd_opts.theme}');</script>\n" + + javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>" + + def template_response(*args, **kwargs): + res = shared.GradioTemplateResponseOriginal(*args, **kwargs) + res.body = res.body.replace( + b'</head>', f'{javascript}</head>'.encode("utf8")) + res.init_headers() + return res + + gradio.routes.templates.TemplateResponse = template_response + + +if not hasattr(shared, 'GradioTemplateResponseOriginal'): + shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse + + +def versions_html(): + import torch + import launch + + python_version = ".".join([str(x) for x in sys.version_info[0:3]]) + commit = launch.commit_hash() + short_commit = commit[0:8] + + if shared.xformers_available: + import xformers + xformers_version = xformers.__version__ + else: + xformers_version = "N/A" + + return f""" +python: <span title="{sys.version}">{python_version}</span> +&#x2000;•&#x2000; +torch: {torch.__version__} +&#x2000;•&#x2000; +xformers: {xformers_version} +&#x2000;•&#x2000; +gradio: {gr.__version__} +&#x2000;•&#x2000; +commit: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{short_commit}</a> +""" -- cgit v1.2.3 From 27ea6949d3206c9a52fa77db587bac0012cb0b52 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Tue, 10 Jan 2023 11:54:48 +0300 Subject: Split history ui.py to ui_progress.py --- modules/temp | 1928 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ modules/ui.py | 1928 --------------------------------------------------------- 2 files changed, 1928 insertions(+), 1928 deletions(-) create mode 100644
modules/temp delete mode 100644 modules/ui.py diff --git a/modules/temp b/modules/temp new file mode 100644 index 00000000..9b9081b5 --- /dev/null +++ b/modules/temp @@ -0,0 +1,1928 @@ +import html +import json +import math +import mimetypes +import os +import platform +import random +import subprocess as sp +import sys +import tempfile +import time +import traceback +from functools import partial, reduce + +import gradio as gr +import gradio.routes +import gradio.utils +import numpy as np +from PIL import Image, PngImagePlugin +from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call + +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru +from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML +from modules.paths import script_path + +from modules.shared import opts, cmd_opts, restricted_opts + +import modules.codeformer_model +import modules.generation_parameters_copypaste as parameters_copypaste +import modules.gfpgan_model +import modules.hypernetworks.ui +import modules.scripts +import modules.shared as shared +import modules.styles +import modules.textual_inversion.ui +from modules import prompt_parser +from modules.images import save_image +from modules.sd_hijack import model_hijack +from modules.sd_samplers import samplers, samplers_for_img2img +from modules.textual_inversion import textual_inversion +import modules.hypernetworks.ui +from modules.generation_parameters_copypaste import image_from_url_text + +# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI +mimetypes.init() +mimetypes.add_type('application/javascript', '.js') + +if not cmd_opts.share and not cmd_opts.listen: + # fix gradio phoning home + gradio.utils.version_check = lambda: None + gradio.utils.get_local_ip_address = lambda: '127.0.0.1' + +if cmd_opts.ngrok is not None: + import modules.ngrok as ngrok + print('ngrok authtoken detected, trying to connect...') + ngrok.connect( + cmd_opts.ngrok, + cmd_opts.port if cmd_opts.port is not None else 7860, + cmd_opts.ngrok_region + ) + + +def gr_show(visible=True): + return {"visible": visible, "__type__": "update"} + + +sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" +sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None + +css_hide_progressbar = """ +.wrap .m-12 svg { display:none!important; } +.wrap .m-12::before { content:"Loading..." } +.wrap .z-20 svg { display:none!important; } +.wrap .z-20::before { content:"Loading..." } +.progress-bar { display:none!important; } +.meta-text { display:none!important; } +.meta-text-center { display:none!important; } +""" + +# Using constants for these since the variation selector isn't visible. +# Important that they exactly match script.js for tooltip to work. +random_symbol = '\U0001f3b2\ufe0f' # 🎲️ +reuse_symbol = '\u267b\ufe0f' # ♻️ +paste_symbol = '\u2199\ufe0f' # ↙ +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +apply_style_symbol = '\U0001f4cb' # 📋 +clear_prompt_symbol = '\U0001F5D1' # 🗑️ + + +def plaintext_to_html(text): + text = "
<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>
" + return text + +def send_gradio_gallery_to_image(x): + if len(x) == 0: + return None + return image_from_url_text(x[0]) + +def save_files(js_data, images, do_make_zip, index): + import csv + filenames = [] + fullfns = [] + + #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it + class MyObject: + def __init__(self, d=None): + if d is not None: + for key, value in d.items(): + setattr(self, key, value) + + data = json.loads(js_data) + + p = MyObject(data) + path = opts.outdir_save + save_to_dirs = opts.use_save_to_dirs_for_ui + extension: str = opts.samples_format + start_index = 0 + + if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only + + images = [images[index]] + start_index = index + + os.makedirs(opts.outdir_save, exist_ok=True) + + with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: + at_start = file.tell() == 0 + writer = csv.writer(file) + if at_start: + writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) + + for image_index, filedata in enumerate(images, start_index): + image = image_from_url_text(filedata) + + is_grid = image_index < p.index_of_first_image + i = 0 if is_grid else (image_index - p.index_of_first_image) + + fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + + filename = os.path.relpath(fullfn, path) + filenames.append(filename) + fullfns.append(fullfn) + if txt_fullfn: + filenames.append(os.path.basename(txt_fullfn)) + fullfns.append(txt_fullfn) + + writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + + # Make Zip + if do_make_zip: + zip_filepath = os.path.join(path, "images.zip") + + from zipfile import ZipFile + with ZipFile(zip_filepath, "w") as zip_file: + for i in range(len(fullfns)): + with open(fullfns[i], mode="rb") as f: + zip_file.writestr(filenames[i], f.read()) + fullfns.insert(0, zip_filepath) + + return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") + + +def calc_time_left(progress, threshold, label, force_display, show_eta): + if progress == 0: + return "" + else: + time_since_start = time.time() - shared.state.time_start + eta = (time_since_start/progress) + eta_relative = eta-time_since_start + if (eta_relative > threshold and show_eta) or force_display: + if eta_relative > 3600: + return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) + elif eta_relative > 60: + return label + time.strftime('%M:%S', time.gmtime(eta_relative)) + else: + return label + time.strftime('%Ss', time.gmtime(eta_relative)) + else: + return "" + + +def check_progress_call(id_part): + if shared.state.job_count == 0: + return "", gr_show(False), gr_show(False), gr_show(False) + + progress = 0 + + if shared.state.job_count > 0: + progress += shared.state.job_no / shared.state.job_count + if shared.state.sampling_steps > 0: + progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps + + # Show progress percentage and time left at the same moment, and base it also on steps done + show_eta = progress >= 0.01 or shared.state.sampling_step >= 10 + + 
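# Editor's aside — worked example of the progress arithmetic above, with
# hypothetical numbers (not part of the patch): with job_count=2, job_no=1,
# sampling_step=10 and sampling_steps=20, progress = 1/2 + (1/2)*(10/20) = 0.75.
# calc_time_left then extrapolates linearly: 30s elapsed gives
# eta = 30 / 0.75 = 40s total, so eta_relative = 40 - 30 = 10s remaining,
# and show_eta is True because progress (0.75) >= 0.01.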
time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta) + if time_left != "": + shared.state.time_left_force_display = True + + progress = min(progress, 1) + + progressbar = "" + if opts.show_progressbar: + progressbar = f"""
<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{"&nbsp;" * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}</div></div>
""" + + image = gr_show(False) + preview_visibility = gr_show(False) + + if opts.show_progress_every_n_steps != 0: + shared.state.set_current_image() + image = shared.state.current_image + + if image is None: + image = gr.update(value=None) + else: + preview_visibility = gr_show(True) + + if shared.state.textinfo is not None: + textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) + else: + textinfo_result = gr_show(False) + + return f"

<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>
", preview_visibility, image, textinfo_result + + +def check_progress_call_initial(id_part): + shared.state.job_count = -1 + shared.state.current_latent = None + shared.state.current_image = None + shared.state.textinfo = None + shared.state.time_start = time.time() + shared.state.time_left_force_display = False + + return check_progress_call(id_part) + + +def visit(x, func, path=""): + if hasattr(x, 'children'): + for c in x.children: + visit(c, func, path) + elif x.label is not None: + func(path + "/" + str(x.label), x) + + +def add_style(name: str, prompt: str, negative_prompt: str): + if name is None: + return [gr_show() for x in range(4)] + + style = modules.styles.PromptStyle(name, prompt, negative_prompt) + shared.prompt_styles.styles[style.name] = style + # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we + # reserialize all styles every time we save them + shared.prompt_styles.save_styles(shared.styles_filename) + + return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] + + +def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): + from modules import processing, devices + + if not enable: + return "" + + p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) + + with devices.autocast(): + p.init([""], [0], [0]) + + return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" + + +def apply_styles(prompt, prompt_neg, style1_name, style2_name): + prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) + prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) + + return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] + + +def interrogate(image): + prompt = shared.interrogator.interrogate(image.convert("RGB")) + + return gr_show(True) if prompt is None else prompt + + +def interrogate_deepbooru(image): + prompt = deepbooru.model.tag(image) + return gr_show(True) if prompt is None else prompt + + +def create_seed_inputs(target_interface): + with FormRow(elem_id=target_interface + '_seed_row'): + seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') + seed.style(container=False) + random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') + reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') + + with gr.Group(elem_id=target_interface + '_subseed_show_box'): + seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) + + # Components to show/hide based on the 'Extra' checkbox + seed_extras = [] + + with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1: + seed_extras.append(seed_extra_row_1) + subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') + subseed.style(container=False) + random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') + reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') + subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + 
'_subseed_strength')
+
+    with FormRow(visible=False) as seed_extra_row_2:
+        seed_extras.append(seed_extra_row_2)
+        seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
+        seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
+
+    random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
+    random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
+
+    def change_visibility(show):
+        return {comp: gr_show(show) for comp in seed_extras}
+
+    seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)
+
+    return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
+
+
+def connect_clear_prompt(button):
+    """Given a clear button, set up its click event to clear the prompt fields via the clear_prompt JavaScript function"""
+    button.click(
+        _js="clear_prompt",
+        fn=None,
+        inputs=[],
+        outputs=[],
+    )
+
+
+def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
+    """ Connects a 'reuse (sub)seed' button's click event so that it copies the last used
+    (sub)seed value from generation info to the seed field. If copying the subseed and subseed strength
+    was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
+    def copy_seed(gen_info_string: str, index):
+        res = -1
+
+        try:
+            gen_info = json.loads(gen_info_string)
+            index -= gen_info.get('index_of_first_image', 0)
+
+            if is_subseed and gen_info.get('subseed_strength', 0) > 0:
+                all_subseeds = gen_info.get('all_subseeds', [-1])
+                res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
+            else:
+                all_seeds = gen_info.get('all_seeds', [-1])
+                res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
+
+        except json.decoder.JSONDecodeError:
+            if gen_info_string != '':
+                print("Error parsing JSON generation info:", file=sys.stderr)
+                print(gen_info_string, file=sys.stderr)
+
+        return [res, gr_show(False)]
+
+    reuse_seed.click(
+        fn=copy_seed,
+        _js="(x, y) => [x, selected_gallery_index()]",
+        show_progress=False,
+        inputs=[generation_info, dummy_component],
+        outputs=[seed, dummy_component]
+    )
+
+
+def update_token_counter(text, steps):
+    try:
+        _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
+        prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
+
+    except Exception:
+        # a parsing error can happen here during typing, and we don't want to bother the user with
+        # messages related to it in console
+        prompt_schedules = [[[steps, text]]]
+
+    flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
+    prompts = [prompt_text for step, prompt_text in flat_prompts]
+    token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
+    style_class = ' class="red"' if (token_count > max_length) else ""
+    return f"<span{style_class}>{token_count}/{max_length}</span>"
+
+
+def create_toprow(is_img2img):
+    id_part = "img2img" if is_img2img else "txt2img"
+
+    with gr.Row(elem_id="toprow"):
+        with gr.Column(scale=6):
+            with gr.Row():
+                with gr.Column(scale=80):
+                    with gr.Row():
+                        prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
+                            placeholder="Prompt (press Ctrl+Enter or 
Alt+Enter to generate)"
+                        )
+
+            with gr.Row():
+                with gr.Column(scale=80):
+                    with gr.Row():
+                        negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
+                            placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
+                        )
+
+        with gr.Column(scale=1, elem_id="roll_col"):
+            paste = gr.Button(value=paste_symbol, elem_id="paste")
+            save_style = gr.Button(value=save_style_symbol, elem_id="style_create")
+            prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply")
+            clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+            token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
+            token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+
+            clear_prompt_button.click(
+                fn=lambda *x: x,
+                _js="confirm_clear_prompt",
+                inputs=[prompt, negative_prompt],
+                outputs=[prompt, negative_prompt],
+            )
+
+        button_interrogate = None
+        button_deepbooru = None
+        if is_img2img:
+            with gr.Column(scale=1, elem_id="interrogate_col"):
+                button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
+                button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
+
+        with gr.Column(scale=1):
+            with gr.Row():
+                skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
+                interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
+                submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
+
+            skip.click(
+                fn=lambda: shared.state.skip(),
+                inputs=[],
+                outputs=[],
+            )
+
+            interrupt.click(
+                fn=lambda: shared.state.interrupt(),
+                inputs=[],
+                outputs=[],
+            )
+
+        with gr.Row():
+            with gr.Column(scale=1, elem_id="style_pos_col"):
+                prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
+
+            with gr.Column(scale=1, elem_id="style_neg_col"):
+                prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())))
+
+    return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
+
+
+def setup_progressbar(progressbar, preview, id_part, textinfo=None):
+    if textinfo is None:
+        textinfo = gr.HTML(visible=False)
+
+    check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
+    check_progress.click(
+        fn=lambda: check_progress_call(id_part),
+        show_progress=False,
+        inputs=[],
+        outputs=[progressbar, preview, preview, textinfo],
+    )
+
+    check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
+    check_progress_initial.click(
+        fn=lambda: check_progress_call_initial(id_part),
+        show_progress=False,
+        inputs=[],
+        outputs=[progressbar, preview, preview, textinfo],
+    )
+
+
+def apply_setting(key, value):
+    if value is None:
+        return gr.update()
+
+    if shared.cmd_opts.freeze_settings:
+        return gr.update()
+
+    # don't allow the model to be swapped when a model hash exists in the prompt
+    if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap:
+        return gr.update()
+
+    if key == "sd_model_checkpoint":
+        ckpt_info = sd_models.get_closet_checkpoint_match(value)
+
+        if ckpt_info is not None:
+            value = ckpt_info.title
+        else:
+            return gr.update()
+
+    comp_args = 
opts.data_labels[key].component_args + if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: + return + + valtype = type(opts.data_labels[key].default) + oldval = opts.data.get(key, None) + opts.data[key] = valtype(value) if valtype != type(None) else value + if oldval != value and opts.data_labels[key].onchange is not None: + opts.data_labels[key].onchange() + + opts.save(shared.config_filename) + return value + + +def update_generation_info(args): + generation_info, html_info, img_index = args + try: + generation_info = json.loads(generation_info) + if img_index < 0 or img_index >= len(generation_info["infotexts"]): + return html_info + return plaintext_to_html(generation_info["infotexts"][img_index]) + except Exception: + pass + # if the json parse or anything else fails, just return the old html_info + return html_info + + +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + + refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + return refresh_button + + +def create_output_panel(tabname, outdir): + def open_folder(f): + if not os.path.exists(f): + print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') + return + elif not os.path.isdir(f): + print(f""" +WARNING +An open_folder request was made with an argument that is not a folder. +This could be an error or a malicious attempt to run code on your computer. +Requested path was: {f} +""", file=sys.stderr) + return + + if not shared.cmd_opts.hide_ui_dir_config: + path = os.path.normpath(f) + if platform.system() == "Windows": + os.startfile(path) + elif platform.system() == "Darwin": + sp.Popen(["open", path]) + elif "microsoft-standard-WSL2" in platform.uname().release: + sp.Popen(["wsl-open", path]) + else: + sp.Popen(["xdg-open", path]) + + with gr.Column(variant='panel'): + with gr.Group(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) + + generation_info = None + with gr.Column(): + with gr.Row(elem_id=f"image_buttons_{tabname}"): + open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') + + if tabname != "extras": + save = gr.Button('Save', elem_id=f'save_{tabname}') + save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') + + buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) + + open_folder_button.click( + fn=lambda: open_folder(opts.outdir_samples or outdir), + inputs=[], + outputs=[], + ) + + if tabname != "extras": + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') + + with gr.Group(): + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') + if tabname == 'txt2img' or tabname == 'img2img': + generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") + generation_info_button.click( + fn=update_generation_info, + _js="(x, y) => [x, y, selected_gallery_index()]", + 
inputs=[generation_info, html_info], + outputs=[html_info], + preprocess=False + ) + + save.click( + fn=wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ] + ) + + save_zip.click( + fn=wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ] + ) + + else: + html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) + return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log + + +def create_sampler_and_steps_selection(choices, tabname): + if opts.samplers_in_dropdown: + with FormRow(elem_id=f"sampler_selection_{tabname}"): + sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) + else: + with FormGroup(elem_id=f"sampler_selection_{tabname}"): + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) + sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + + return steps, sampler_index + + +def ordered_ui_categories(): + user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))} + + for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)): + yield category + + +def create_ui(): + import modules.img2img + import modules.txt2img + + reload_javascript() + + parameters_copypaste.reset() + + modules.scripts.scripts_current = modules.scripts.scripts_txt2img + modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) + + with gr.Blocks(analytics_enabled=False) as txt2img_interface: + txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) + + dummy_component = gr.Label(visible=False) + txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) + + with gr.Row(elem_id='txt2img_progress_row'): + with gr.Column(scale=1): + pass + + with gr.Column(scale=1): + progressbar = gr.HTML(elem_id="txt2img_progressbar") + txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) + setup_progressbar(progressbar, txt2img_preview, 'txt2img') + + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel', elem_id="txt2img_settings"): + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") + + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="txt2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") + height = gr.Slider(minimum=64, 
maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") + + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "cfg": + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') + + elif category == "checkboxes": + with FormRow(elem_id="txt2img_checkboxes"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") + enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") + hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) + + elif category == "hires_fix": + with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: + with FormRow(elem_id="txt2img_hires_fix_row1"): + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") + + with FormRow(elem_id="txt2img_hires_fix_row2"): + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") + hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") + hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") + + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "scripts": + with FormGroup(elem_id="txt2img_script_container"): + custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + + hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] + for input in hr_resolution_preview_inputs: + input.change( + fn=calc_resolution_hires, + inputs=hr_resolution_preview_inputs, + outputs=[hr_final_resolution], + show_progress=False, + ) + input.change( + None, + _js="onCalcResolutionHires", + inputs=hr_resolution_preview_inputs, + outputs=[], + show_progress=False, + ) + + txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) + parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, 
dummy_component, is_subseed=True) + + txt2img_args = dict( + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), + _js="submit", + inputs=[ + txt2img_prompt, + txt2img_negative_prompt, + txt2img_prompt_style, + txt2img_prompt_style2, + steps, + sampler_index, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + enable_hr, + denoising_strength, + hr_scale, + hr_upscaler, + hr_second_pass_steps, + hr_resize_x, + hr_resize_y, + ] + custom_inputs, + + outputs=[ + txt2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + + txt2img_prompt.submit(**txt2img_args) + submit.click(**txt2img_args) + + txt_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + txt_prompt_img + ], + outputs=[ + txt2img_prompt, + txt_prompt_img + ] + ) + + enable_hr.change( + fn=lambda x: gr_show(x), + inputs=[enable_hr], + outputs=[hr_options], + show_progress = False, + ) + + txt2img_paste_fields = [ + (txt2img_prompt, "Prompt"), + (txt2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (denoising_strength, "Denoising strength"), + (enable_hr, lambda d: "Denoising strength" in d), + (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), + (hr_scale, "Hires upscale"), + (hr_upscaler, "Hires upscaler"), + (hr_second_pass_steps, "Hires steps"), + (hr_resize_x, "Hires resize-1"), + (hr_resize_y, "Hires resize-2"), + *modules.scripts.scripts_txt2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) + + txt2img_preview_params = [ + txt2img_prompt, + txt2img_negative_prompt, + steps, + sampler_index, + cfg_scale, + seed, + width, + height, + ] + + token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) + + modules.scripts.scripts_current = modules.scripts.scripts_img2img + modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) + + with gr.Blocks(analytics_enabled=False) as img2img_interface: + img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) + + with gr.Row(elem_id='img2img_progress_row'): + img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) + + with gr.Column(scale=1): + pass + + with gr.Column(scale=1): + progressbar = gr.HTML(elem_id="img2img_progressbar") + img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) + setup_progressbar(progressbar, img2img_preview, 'img2img') + + with FormRow().style(equal_height=False): + with gr.Column(variant='panel', elem_id="img2img_settings"): + + with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: + with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", 
interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) + + with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) + init_img_with_mask_orig = gr.State(None) + + use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch" + if use_color_sketch: + def update_orig(image, state): + if image is not None: + same_size = state is not None and state.size == image.size + has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + edited = same_size and has_exact_match + return image if not edited or state is None else state + + init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) + + init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") + + with FormRow(): + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") + mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") + + with FormRow(): + mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") + inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") + + with FormRow(): + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") + + with FormRow(): + with gr.Column(): + inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + + with gr.Column(scale=4): + inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") + + with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): + hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
+                        gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>
") + img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") + img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + + with FormRow(): + resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") + + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="img2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") + + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + elif category == "cfg": + with FormGroup(): + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') + + elif category == "checkboxes": + with FormRow(elem_id="img2img_checkboxes"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") + + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + elif category == "scripts": + with FormGroup(elem_id="img2img_script_container"): + custom_inputs = modules.scripts.scripts_img2img.setup_ui() + + img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) + parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + + img2img_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + img2img_prompt_img + ], + outputs=[ + img2img_prompt, + img2img_prompt_img + ] + ) + + mask_mode.change( + lambda mode, img: { + init_img_with_mask: gr_show(mode == 0), + init_img_inpaint: gr_show(mode == 1), + init_mask_inpaint: gr_show(mode == 1), + }, + inputs=[mask_mode, init_img_with_mask], + outputs=[ + init_img_with_mask, + init_img_inpaint, + init_mask_inpaint, + ], + ) + + img2img_args = dict( + fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), + _js="submit_img2img", + inputs=[ + 
dummy_component, + img2img_prompt, + img2img_negative_prompt, + img2img_prompt_style, + img2img_prompt_style2, + init_img, + init_img_with_mask, + init_img_with_mask_orig, + init_img_inpaint, + init_mask_inpaint, + mask_mode, + steps, + sampler_index, + mask_blur, + mask_alpha, + inpainting_fill, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + denoising_strength, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + resize_mode, + inpaint_full_res, + inpaint_full_res_padding, + inpainting_mask_invert, + img2img_batch_input_dir, + img2img_batch_output_dir, + ] + custom_inputs, + outputs=[ + img2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + + img2img_prompt.submit(**img2img_args) + submit.click(**img2img_args) + + img2img_interrogate.click( + fn=interrogate, + inputs=[init_img], + outputs=[img2img_prompt], + ) + + img2img_deepbooru.click( + fn=interrogate_deepbooru, + inputs=[init_img], + outputs=[img2img_prompt], + ) + + prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] + style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] + style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] + + for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): + button.click( + fn=add_style, + _js="ask_for_style_name", + # Have to pass empty dummy component here, because the JavaScript and Python function have to accept + # the same number of parameters, but we only know the style-name after the JavaScript prompt + inputs=[dummy_component, prompt, negative_prompt], + outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], + ) + + for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): + button.click( + fn=apply_styles, + _js=js_func, + inputs=[prompt, negative_prompt, style1, style2], + outputs=[prompt, negative_prompt, style1, style2], + ) + + token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) + + img2img_paste_fields = [ + (img2img_prompt, "Prompt"), + (img2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (denoising_strength, "Denoising strength"), + (mask_blur, "Mask blur"), + *modules.scripts.scripts_img2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) + parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) + + modules.scripts.scripts_current = None + + with gr.Blocks(analytics_enabled=False) as extras_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + with gr.Tabs(elem_id="mode_extras"): + with gr.TabItem('Single Image', elem_id="extras_single_tab"): + extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") + + with gr.TabItem('Batch Process', 
elem_id="extras_batch_process_tab"): + image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") + + with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): + extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") + extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") + show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") + + submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') + + with gr.Tabs(elem_id="extras_resize_mode"): + with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") + with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): + with gr.Group(): + with gr.Row(): + upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") + + with gr.Group(): + extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + + with gr.Group(): + extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") + + with gr.Group(): + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") + + with gr.Group(): + codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") + codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") + + with gr.Group(): + upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") + + result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) + + submit.click( + fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']), + _js="get_extras_tab_index", + inputs=[ + dummy_component, + dummy_component, + extras_image, + image_batch, + extras_batch_input_dir, + extras_batch_output_dir, + show_extras_results, + gfpgan_visibility, + codeformer_visibility, + codeformer_weight, + upscaling_resize, + upscaling_resize_w, + upscaling_resize_h, + upscaling_crop, + extras_upscaler_1, + extras_upscaler_2, + extras_upscaler_2_visibility, + upscale_before_face_fix, + ], + outputs=[ + result_images, + html_info_x, 
+ html_info, + ] + ) + parameters_copypaste.add_paste_fields("extras", extras_image, None) + + extras_image.change( + fn=modules.extras.clear_cache, + inputs=[], outputs=[] + ) + + with gr.Blocks(analytics_enabled=False) as pnginfo_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") + + with gr.Column(variant='panel'): + html = gr.HTML() + generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info") + html2 = gr.HTML() + with gr.Row(): + buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) + parameters_copypaste.bind_buttons(buttons, image, generation_info) + + image.change( + fn=wrap_gradio_call(modules.extras.run_pnginfo), + inputs=[image], + outputs=[html, generation_info, html2], + ) + + with gr.Blocks(analytics_enabled=False) as modelmerger_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + gr.HTML(value="
<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your checkpoint directory.</p>
") + + with gr.Row(): + primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") + create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") + + secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") + create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B") + + tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") + create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") + + custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") + interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") + interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") + + with gr.Row(): + checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") + save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") + + modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') + + with gr.Column(variant='panel'): + submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) + + with gr.Blocks(analytics_enabled=False) as train_interface: + with gr.Row().style(equal_height=False): + gr.HTML(value="
<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>
") + + with gr.Row().style(equal_height=False): + with gr.Tabs(elem_id="train_tabs"): + + with gr.Tab(label="Create embedding"): + new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") + initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") + nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") + overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") + + with gr.Tab(label="Create hypernetwork"): + new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") + new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") + new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") + new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") + new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") + new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") + new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") + new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. 
ex:'0, 0.01, 0'") + overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") + + with gr.Tab(label="Preprocess images"): + process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") + process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") + process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") + process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") + preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") + + with gr.Row(): + process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") + process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") + process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") + process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") + process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") + + with gr.Row(visible=False) as process_split_extra_row: + process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") + process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") + + with gr.Row(visible=False) as process_focal_crop_row: + process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") + process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") + process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") + process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + with gr.Row(): + interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") + run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") + + process_split.change( + fn=lambda show: gr_show(show), + inputs=[process_split], + outputs=[process_split_extra_row], + ) + + process_focal_crop.change( + fn=lambda show: gr_show(show), + inputs=[process_focal_crop], + outputs=[process_focal_crop_row], + ) + + def get_textual_inversion_template_names(): + return sorted([x for x in textual_inversion.textual_inversion_templates]) + + with gr.Tab(label="Train"): + gr.HTML(value="
<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>
") + with FormRow(): + train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") + + train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) + create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") + + with FormRow(): + embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") + hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") + + with FormRow(): + clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) + clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) + + with FormRow(): + batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") + gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") + + dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") + log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") + + with FormRow(): + template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) + create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") + + training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") + training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") + varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") + steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") + + with FormRow(): + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") + + save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") + preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") + + shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") + tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") + + latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") + + with gr.Row(): + train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") + interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") + train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") + + params = script_callbacks.UiTrainTabParams(txt2img_preview_params) + + script_callbacks.ui_train_tabs_callback(params) + + with gr.Column(): + progressbar = gr.HTML(elem_id="ti_progressbar") + ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) + + ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) + ti_preview = gr.Image(elem_id='ti_preview', visible=False) + ti_progress = gr.HTML(elem_id="ti_progress", value="") + ti_outcome = gr.HTML(elem_id="ti_error", value="") + setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) + + create_embedding.click( + fn=modules.textual_inversion.ui.create_embedding, + inputs=[ + new_embedding_name, + initialization_text, + nvpt, + overwrite_old_embedding, + ], + outputs=[ + train_embedding_name, + ti_output, + ti_outcome, + ] + ) + + create_hypernetwork.click( + fn=modules.hypernetworks.ui.create_hypernetwork, + inputs=[ + new_hypernetwork_name, + new_hypernetwork_sizes, + overwrite_old_hypernetwork, + new_hypernetwork_layer_structure, + new_hypernetwork_activation_func, + new_hypernetwork_initialization_option, + new_hypernetwork_add_layer_norm, + new_hypernetwork_use_dropout, + new_hypernetwork_dropout_structure + ], + outputs=[ + train_hypernetwork_name, + ti_output, + ti_outcome, + ] + ) + + run_preprocess.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + process_src, + process_dst, + process_width, + process_height, + preprocess_txt_action, + process_flip, + process_split, + process_caption, + process_caption_deepbooru, + process_split_threshold, + process_overlap_ratio, + process_focal_crop, + process_focal_crop_face_weight, + process_focal_crop_entropy_weight, + process_focal_crop_edges_weight, + process_focal_crop_debug, + ], + outputs=[ + ti_output, + ti_outcome, + ], + ) + + train_embedding.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_embedding_name, + embedding_learn_rate, + batch_size, + gradient_step, + dataset_directory, + log_directory, + training_width, + training_height, + varsize, + steps, + clip_grad_mode, + clip_grad_value, + shuffle_tags, + tag_drop_out, + latent_sampling_method, + create_image_every, + save_embedding_every, + template_file, + save_image_with_stored_embedding, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + train_hypernetwork.click( + fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, 
extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_hypernetwork_name, + hypernetwork_learn_rate, + batch_size, + gradient_step, + dataset_directory, + log_directory, + training_width, + training_height, + varsize, + steps, + clip_grad_mode, + clip_grad_value, + shuffle_tags, + tag_drop_out, + latent_sampling_method, + create_image_every, + save_embedding_every, + template_file, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + interrupt_training.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + interrupt_preprocessing.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + def create_setting_component(key, is_quicksettings=False): + def fun(): + return opts.data[key] if key in opts.data else opts.data_labels[key].default + + info = opts.data_labels[key] + t = type(info.default) + + args = info.component_args() if callable(info.component_args) else info.component_args + + if info.component is not None: + comp = info.component + elif t == str: + comp = gr.Textbox + elif t == int: + comp = gr.Number + elif t == bool: + comp = gr.Checkbox + else: + raise Exception(f'bad options item type: {str(t)} for key {key}') + + elem_id = "setting_"+key + + if info.refresh is not None: + if is_quicksettings: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + with FormRow(): + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + + return res + + components = [] + component_dict = {} + + script_callbacks.ui_settings_callback() + opts.reorder() + + def run_settings(*args): + changed = [] + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + if comp == dummy_component: + continue + + if opts.set(key, value): + changed.append(key) + + try: + opts.save(shared.config_filename) + except RuntimeError: + return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' 
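
The two-pass shape of run_settings above is worth noting: every incoming value is first type-checked against its default, and only then are the options applied and saved, so a single bad value aborts the whole batch before any option is mutated. A minimal sketch of that validate-then-apply pattern follows, using a hypothetical ToyOptions class in place of the real shared.opts (the names and API here are illustrative assumptions, not the actual Options implementation):

    # hypothetical stand-in for shared.opts, for illustration only
    class ToyOptions:
        def __init__(self, defaults):
            self.defaults = dict(defaults)  # key -> default value; also fixes each key's expected type
            self.data = dict(defaults)

        def same_type(self, a, b):
            return type(a) == type(b)

        def set(self, key, value):
            changed = self.data.get(key) != value
            self.data[key] = value
            return changed  # True only if the stored value actually changed

    def apply_settings(opts, incoming):
        # pass 1: validate everything before touching anything
        for key, value in incoming.items():
            assert opts.same_type(value, opts.defaults[key]), f"Bad value for setting {key}: {value!r}"

        # pass 2: apply, collecting the keys that actually changed
        return [key for key, value in incoming.items() if opts.set(key, value)]

    opts = ToyOptions({"steps": 20, "grid_format": "png"})
    print(apply_settings(opts, {"steps": 25}))  # -> ['steps']

Validating before applying keeps a failed submit from leaving the options half-updated, which is why the real function runs its assertion loop to completion before calling opts.set at all.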
+ + def run_settings_single(value, key): + if not opts.same_type(value, opts.data_labels[key].default): + return gr.update(visible=True), opts.dumpjson() + + if not opts.set(key, value): + return gr.update(value=getattr(opts, key)), opts.dumpjson() + + opts.save(shared.config_filename) + + return gr.update(value=value), opts.dumpjson() + + with gr.Blocks(analytics_enabled=False) as settings_interface: + with gr.Row(): + with gr.Column(scale=6): + settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") + with gr.Column(): + restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") + + result = gr.HTML(elem_id="settings_result") + + quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] + quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} + + quicksettings_list = [] + + previous_section = None + current_tab = None + with gr.Tabs(elem_id="settings"): + for i, (k, item) in enumerate(opts.data_labels.items()): + section_must_be_skipped = item.section[0] is None + + if previous_section != item.section and not section_must_be_skipped: + elem_id, text = item.section + + if current_tab is not None: + current_tab.__exit__() + + current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) + current_tab.__enter__() + + previous_section = item.section + + if k in quicksettings_names and not shared.cmd_opts.freeze_settings: + quicksettings_list.append((i, k, item)) + components.append(dummy_component) + elif section_must_be_skipped: + components.append(dummy_component) + else: + component = create_setting_component(k) + component_dict[k] = component + components.append(component) + + if current_tab is not None: + current_tab.__exit__() + + with gr.TabItem("Actions"): + request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") + + if os.path.exists("html/licenses.html"): + with open("html/licenses.html", encoding="utf8") as file: + with gr.TabItem("Licenses"): + gr.HTML(file.read(), elem_id="licenses") + + gr.Button(value="Show all pages", elem_id="settings_show_all_pages") + + request_notifications.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='function(){}' + ) + + download_localization.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='download_localization' + ) + + def reload_scripts(): + modules.scripts.reload_script_body_only() + reload_javascript() # need to refresh the html page + + reload_script_bodies.click( + fn=reload_scripts, + inputs=[], + outputs=[] + ) + + def request_restart(): + shared.state.interrupt() + shared.state.need_restart = True + + restart_gradio.click( + fn=request_restart, + _js='restart_reload', + inputs=[], + outputs=[], + ) + + interfaces = [ + (txt2img_interface, "txt2img", "txt2img"), + (img2img_interface, "img2img", "img2img"), + (extras_interface, "Extras", "extras"), + (pnginfo_interface, "PNG Info", "pnginfo"), + (modelmerger_interface, "Checkpoint Merger", "modelmerger"), + (train_interface, "Train", "ti"), + ] + + css = "" + + for cssfile in modules.scripts.list_files_with_name("style.css"): + if not os.path.isfile(cssfile): + continue + + with open(cssfile, "r", 
encoding="utf8") as file: + css += file.read() + "\n" + + if os.path.exists(os.path.join(script_path, "user.css")): + with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: + css += file.read() + "\n" + + if not cmd_opts.no_progressbar_hiding: + css += css_hide_progressbar + + interfaces += script_callbacks.ui_tabs_callback() + interfaces += [(settings_interface, "Settings", "settings")] + + extensions_interface = ui_extensions.create_ui() + interfaces += [(extensions_interface, "Extensions", "extensions")] + + with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: + with gr.Row(elem_id="quicksettings"): + for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): + component = create_setting_component(k, is_quicksettings=True) + component_dict[k] = component + + parameters_copypaste.integrate_settings_paste_fields(component_dict) + parameters_copypaste.run_bind() + + with gr.Tabs(elem_id="tabs") as tabs: + for interface, label, ifid in interfaces: + with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): + interface.render() + + if os.path.exists(os.path.join(script_path, "notification.mp3")): + audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) + + if os.path.exists("html/footer.html"): + with open("html/footer.html", encoding="utf8") as file: + footer = file.read() + footer = footer.format(versions=versions_html()) + gr.HTML(footer, elem_id="footer") + + text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) + settings_submit.click( + fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), + inputs=components, + outputs=[text_settings, result], + ) + + for i, k, item in quicksettings_list: + component = component_dict[k] + + component.change( + fn=lambda value, k=k: run_settings_single(value, key=k), + inputs=[component], + outputs=[component, text_settings], + ) + + component_keys = [k for k in opts.data_labels.keys() if k in component_dict] + + def get_settings_values(): + return [getattr(opts, key) for key in component_keys] + + demo.load( + fn=get_settings_values, + inputs=[], + outputs=[component_dict[k] for k in component_keys], + ) + + def modelmerger(*args): + try: + results = modules.extras.run_modelmerger(*args) + except Exception as e: + print("Error loading/saving model file:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + modules.sd_models.list_models() # to remove the potentially missing models from the list + return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)] + return results + + modelmerger_merge.click( + fn=modelmerger, + inputs=[ + primary_model_name, + secondary_model_name, + tertiary_model_name, + interp_method, + interp_amount, + save_as_half, + custom_name, + checkpoint_format, + ], + outputs=[ + submit_result, + primary_model_name, + secondary_model_name, + tertiary_model_name, + component_dict['sd_model_checkpoint'], + ] + ) + + ui_config_file = cmd_opts.ui_config_file + ui_settings = {} + settings_count = len(ui_settings) + error_loading = False + + try: + if os.path.exists(ui_config_file): + with open(ui_config_file, "r", encoding="utf8") as file: + ui_settings = json.load(file) + except Exception: + error_loading = True + print("Error loading settings:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + def 
loadsave(path, x): + def apply_field(obj, field, condition=None, init_field=None): + key = path + "/" + field + + if getattr(obj, 'custom_script_source', None) is not None: + key = 'customscript/' + obj.custom_script_source + '/' + key + + if getattr(obj, 'do_not_save_to_config', False): + return + + saved_value = ui_settings.get(key, None) + if saved_value is None: + ui_settings[key] = getattr(obj, field) + elif condition and not condition(saved_value): + print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') + else: + setattr(obj, field, saved_value) + if init_field is not None: + init_field(saved_value) + + if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible: + apply_field(x, 'visible') + + if type(x) == gr.Slider: + apply_field(x, 'value') + apply_field(x, 'minimum') + apply_field(x, 'maximum') + apply_field(x, 'step') + + if type(x) == gr.Radio: + apply_field(x, 'value', lambda val: val in x.choices) + + if type(x) == gr.Checkbox: + apply_field(x, 'value') + + if type(x) == gr.Textbox: + apply_field(x, 'value') + + if type(x) == gr.Number: + apply_field(x, 'value') + + if type(x) == gr.Dropdown: + apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) + + visit(txt2img_interface, loadsave, "txt2img") + visit(img2img_interface, loadsave, "img2img") + visit(extras_interface, loadsave, "extras") + visit(modelmerger_interface, loadsave, "modelmerger") + visit(train_interface, loadsave, "train") + + if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): + with open(ui_config_file, "w", encoding="utf8") as file: + json.dump(ui_settings, file, indent=4) + + return demo + + +def reload_javascript(): + with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: + javascript = f'<script>{jsfile.read()}</script>' + + scripts_list = modules.scripts.list_scripts("javascript", ".js") + + for basedir, filename, path in scripts_list: + with open(path, "r", encoding="utf8") as jsfile: + javascript += f"\n<script>{jsfile.read()}</script>" + + if cmd_opts.theme is not None: + javascript += f"\n<script>set_theme('{cmd_opts.theme}');</script>\n" + + javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>" + + def template_response(*args, **kwargs): + res = shared.GradioTemplateResponseOriginal(*args, **kwargs) + res.body = res.body.replace( + b'</head>', f'{javascript}</head>'.encode("utf8")) + res.init_headers() + return res + + gradio.routes.templates.TemplateResponse = template_response + + +if not hasattr(shared, 'GradioTemplateResponseOriginal'): + shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse + + +def versions_html(): + import torch + import launch + + python_version = ".".join([str(x) for x in sys.version_info[0:3]]) + commit = launch.commit_hash() + short_commit = commit[0:8] + + if shared.xformers_available: + import xformers + xformers_version = xformers.__version__ + else: + xformers_version = "N/A" + + return f""" +python: <span title="{sys.version}">{python_version}</span> +&#x2000;•&#x2000; +torch: {torch.__version__} +&#x2000;•&#x2000; +xformers: {xformers_version} +&#x2000;•&#x2000; +gradio: {gr.__version__} +&#x2000;•&#x2000; +commit: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/commit/{commit}">{short_commit}</a> +""" diff --git a/modules/ui.py b/modules/ui.py deleted file mode 100644 index 9b9081b5..00000000 --- a/modules/ui.py +++ /dev/null @@ -1,1928 +0,0 @@ -import html -import json -import math -import mimetypes -import os -import platform -import random -import subprocess as sp -import sys -import tempfile -import time -import traceback -from functools import partial, reduce - -import gradio as gr -import gradio.routes -import gradio.utils -import 
numpy as np -from PIL import Image, PngImagePlugin -from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call - -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru -from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML -from modules.paths import script_path - -from modules.shared import opts, cmd_opts, restricted_opts - -import modules.codeformer_model -import modules.generation_parameters_copypaste as parameters_copypaste -import modules.gfpgan_model -import modules.hypernetworks.ui -import modules.scripts -import modules.shared as shared -import modules.styles -import modules.textual_inversion.ui -from modules import prompt_parser -from modules.images import save_image -from modules.sd_hijack import model_hijack -from modules.sd_samplers import samplers, samplers_for_img2img -from modules.textual_inversion import textual_inversion -import modules.hypernetworks.ui -from modules.generation_parameters_copypaste import image_from_url_text - -# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI -mimetypes.init() -mimetypes.add_type('application/javascript', '.js') - -if not cmd_opts.share and not cmd_opts.listen: - # fix gradio phoning home - gradio.utils.version_check = lambda: None - gradio.utils.get_local_ip_address = lambda: '127.0.0.1' - -if cmd_opts.ngrok is not None: - import modules.ngrok as ngrok - print('ngrok authtoken detected, trying to connect...') - ngrok.connect( - cmd_opts.ngrok, - cmd_opts.port if cmd_opts.port is not None else 7860, - cmd_opts.ngrok_region - ) - - -def gr_show(visible=True): - return {"visible": visible, "__type__": "update"} - - -sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" -sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None - -css_hide_progressbar = """ -.wrap .m-12 svg { display:none!important; } -.wrap .m-12::before { content:"Loading..." } -.wrap .z-20 svg { display:none!important; } -.wrap .z-20::before { content:"Loading..." } -.progress-bar { display:none!important; } -.meta-text { display:none!important; } -.meta-text-center { display:none!important; } -""" - -# Using constants for these since the variation selector isn't visible. -# Important that they exactly match script.js for tooltip to work. -random_symbol = '\U0001f3b2\ufe0f' # 🎲️ -reuse_symbol = '\u267b\ufe0f' # ♻️ -paste_symbol = '\u2199\ufe0f' # ↙ -folder_symbol = '\U0001f4c2' # 📂 -refresh_symbol = '\U0001f504' # 🔄 -save_style_symbol = '\U0001f4be' # 💾 -apply_style_symbol = '\U0001f4cb' # 📋 -clear_prompt_symbol = '\U0001F5D1' # 🗑️ - - -def plaintext_to_html(text): - text = "
<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>
" - return text - -def send_gradio_gallery_to_image(x): - if len(x) == 0: - return None - return image_from_url_text(x[0]) - -def save_files(js_data, images, do_make_zip, index): - import csv - filenames = [] - fullfns = [] - - #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it - class MyObject: - def __init__(self, d=None): - if d is not None: - for key, value in d.items(): - setattr(self, key, value) - - data = json.loads(js_data) - - p = MyObject(data) - path = opts.outdir_save - save_to_dirs = opts.use_save_to_dirs_for_ui - extension: str = opts.samples_format - start_index = 0 - - if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only - - images = [images[index]] - start_index = index - - os.makedirs(opts.outdir_save, exist_ok=True) - - with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: - at_start = file.tell() == 0 - writer = csv.writer(file) - if at_start: - writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) - - for image_index, filedata in enumerate(images, start_index): - image = image_from_url_text(filedata) - - is_grid = image_index < p.index_of_first_image - i = 0 if is_grid else (image_index - p.index_of_first_image) - - fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) - - filename = os.path.relpath(fullfn, path) - filenames.append(filename) - fullfns.append(fullfn) - if txt_fullfn: - filenames.append(os.path.basename(txt_fullfn)) - fullfns.append(txt_fullfn) - - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) - - # Make Zip - if do_make_zip: - zip_filepath = os.path.join(path, "images.zip") - - from zipfile import ZipFile - with ZipFile(zip_filepath, "w") as zip_file: - for i in range(len(fullfns)): - with open(fullfns[i], mode="rb") as f: - zip_file.writestr(filenames[i], f.read()) - fullfns.insert(0, zip_filepath) - - return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") - - -def calc_time_left(progress, threshold, label, force_display, show_eta): - if progress == 0: - return "" - else: - time_since_start = time.time() - shared.state.time_start - eta = (time_since_start/progress) - eta_relative = eta-time_since_start - if (eta_relative > threshold and show_eta) or force_display: - if eta_relative > 3600: - return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) - elif eta_relative > 60: - return label + time.strftime('%M:%S', time.gmtime(eta_relative)) - else: - return label + time.strftime('%Ss', time.gmtime(eta_relative)) - else: - return "" - - -def check_progress_call(id_part): - if shared.state.job_count == 0: - return "", gr_show(False), gr_show(False), gr_show(False) - - progress = 0 - - if shared.state.job_count > 0: - progress += shared.state.job_no / shared.state.job_count - if shared.state.sampling_steps > 0: - progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps - - # Show progress percentage and time left at the same moment, and base it also on steps done - show_eta = progress >= 0.01 or shared.state.sampling_step >= 10 - - 
time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta) - if time_left != "": - shared.state.time_left_force_display = True - - progress = min(progress, 1) - - progressbar = "" - if opts.show_progressbar: - progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}
""" - - image = gr_show(False) - preview_visibility = gr_show(False) - - if opts.show_progress_every_n_steps != 0: - shared.state.set_current_image() - image = shared.state.current_image - - if image is None: - image = gr.update(value=None) - else: - preview_visibility = gr_show(True) - - if shared.state.textinfo is not None: - textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) - else: - textinfo_result = gr_show(False) - - return f"
<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>
", preview_visibility, image, textinfo_result - - -def check_progress_call_initial(id_part): - shared.state.job_count = -1 - shared.state.current_latent = None - shared.state.current_image = None - shared.state.textinfo = None - shared.state.time_start = time.time() - shared.state.time_left_force_display = False - - return check_progress_call(id_part) - - -def visit(x, func, path=""): - if hasattr(x, 'children'): - for c in x.children: - visit(c, func, path) - elif x.label is not None: - func(path + "/" + str(x.label), x) - - -def add_style(name: str, prompt: str, negative_prompt: str): - if name is None: - return [gr_show() for x in range(4)] - - style = modules.styles.PromptStyle(name, prompt, negative_prompt) - shared.prompt_styles.styles[style.name] = style - # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we - # reserialize all styles every time we save them - shared.prompt_styles.save_styles(shared.styles_filename) - - return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] - - -def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): - from modules import processing, devices - - if not enable: - return "" - - p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) - - with devices.autocast(): - p.init([""], [0], [0]) - - return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" - - -def apply_styles(prompt, prompt_neg, style1_name, style2_name): - prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) - prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) - - return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] - - -def interrogate(image): - prompt = shared.interrogator.interrogate(image.convert("RGB")) - - return gr_show(True) if prompt is None else prompt - - -def interrogate_deepbooru(image): - prompt = deepbooru.model.tag(image) - return gr_show(True) if prompt is None else prompt - - -def create_seed_inputs(target_interface): - with FormRow(elem_id=target_interface + '_seed_row'): - seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') - seed.style(container=False) - random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') - reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') - - with gr.Group(elem_id=target_interface + '_subseed_show_box'): - seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) - - # Components to show/hide based on the 'Extra' checkbox - seed_extras = [] - - with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1: - seed_extras.append(seed_extra_row_1) - subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') - subseed.style(container=False) - random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') - reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') - subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + 
'_subseed_strength') - - with FormRow(visible=False) as seed_extra_row_2: - seed_extras.append(seed_extra_row_2) - seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w') - seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h') - - random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) - random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) - - def change_visibility(show): - return {comp: gr_show(show) for comp in seed_extras} - - seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras) - - return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox - - - -def connect_clear_prompt(button): - """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" - button.click( - _js="clear_prompt", - fn=None, - inputs=[], - outputs=[], - ) - - -def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): - """ Connects a 'reuse (sub)seed' button's click event so that it copies last used - (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength - was 0, i.e. no variation seed was used, it copies the normal seed value instead.""" - def copy_seed(gen_info_string: str, index): - res = -1 - - try: - gen_info = json.loads(gen_info_string) - index -= gen_info.get('index_of_first_image', 0) - - if is_subseed and gen_info.get('subseed_strength', 0) > 0: - all_subseeds = gen_info.get('all_subseeds', [-1]) - res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] - else: - all_seeds = gen_info.get('all_seeds', [-1]) - res = all_seeds[index if 0 <= index < len(all_seeds) else 0] - - except json.decoder.JSONDecodeError as e: - if gen_info_string != '': - print("Error parsing JSON generation info:", file=sys.stderr) - print(gen_info_string, file=sys.stderr) - - return [res, gr_show(False)] - - reuse_seed.click( - fn=copy_seed, - _js="(x, y) => [x, selected_gallery_index()]", - show_progress=False, - inputs=[generation_info, dummy_component], - outputs=[seed, dummy_component] - ) - - -def update_token_counter(text, steps): - try: - _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) - prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) - - except Exception: - # a parsing error can happen here during typing, and we don't want to bother the user with - # messages related to it in console - prompt_schedules = [[[steps, text]]] - - flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) - prompts = [prompt_text for step, prompt_text in flat_prompts] - token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) - style_class = ' class="red"' if (token_count > max_length) else "" - return f"{token_count}/{max_length}" - - -def create_toprow(is_img2img): - id_part = "img2img" if is_img2img else "txt2img" - - with gr.Row(elem_id="toprow"): - with gr.Column(scale=6): - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, - placeholder="Prompt (press Ctrl+Enter or 
Alt+Enter to generate)" - ) - - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, - placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" - ) - - with gr.Column(scale=1, elem_id="roll_col"): - paste = gr.Button(value=paste_symbol, elem_id="paste") - save_style = gr.Button(value=save_style_symbol, elem_id="style_create") - prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") - token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") - token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - - clear_prompt_button.click( - fn=lambda *x: x, - _js="confirm_clear_prompt", - inputs=[prompt, negative_prompt], - outputs=[prompt, negative_prompt], - ) - - button_interrogate = None - button_deepbooru = None - if is_img2img: - with gr.Column(scale=1, elem_id="interrogate_col"): - button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") - - with gr.Column(scale=1): - with gr.Row(): - skip = gr.Button('Skip', elem_id=f"{id_part}_skip") - interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") - submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') - - skip.click( - fn=lambda: shared.state.skip(), - inputs=[], - outputs=[], - ) - - interrupt.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - with gr.Row(): - with gr.Column(scale=1, elem_id="style_pos_col"): - prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - - with gr.Column(scale=1, elem_id="style_neg_col"): - prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - - return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button - - -def setup_progressbar(progressbar, preview, id_part, textinfo=None): - if textinfo is None: - textinfo = gr.HTML(visible=False) - - check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) - check_progress.click( - fn=lambda: check_progress_call(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) - check_progress_initial.click( - fn=lambda: check_progress_call_initial(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - -def apply_setting(key, value): - if value is None: - return gr.update() - - if shared.cmd_opts.freeze_settings: - return gr.update() - - # dont allow model to be swapped when model hash exists in prompt - if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: - return gr.update() - - if key == "sd_model_checkpoint": - ckpt_info = sd_models.get_closet_checkpoint_match(value) - - if ckpt_info is not None: - value = ckpt_info.title - else: - return gr.update() - - comp_args = 
opts.data_labels[key].component_args - if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: - return - - valtype = type(opts.data_labels[key].default) - oldval = opts.data.get(key, None) - opts.data[key] = valtype(value) if valtype != type(None) else value - if oldval != value and opts.data_labels[key].onchange is not None: - opts.data_labels[key].onchange() - - opts.save(shared.config_filename) - return value - - -def update_generation_info(args): - generation_info, html_info, img_index = args - try: - generation_info = json.loads(generation_info) - if img_index < 0 or img_index >= len(generation_info["infotexts"]): - return html_info - return plaintext_to_html(generation_info["infotexts"][img_index]) - except Exception: - pass - # if the json parse or anything else fails, just return the old html_info - return html_info - - -def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): - def refresh(): - refresh_method() - args = refreshed_args() if callable(refreshed_args) else refreshed_args - - for k, v in args.items(): - setattr(refresh_component, k, v) - - return gr.update(**(args or {})) - - refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) - refresh_button.click( - fn=refresh, - inputs=[], - outputs=[refresh_component] - ) - return refresh_button - - -def create_output_panel(tabname, outdir): - def open_folder(f): - if not os.path.exists(f): - print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') - return - elif not os.path.isdir(f): - print(f""" -WARNING -An open_folder request was made with an argument that is not a folder. -This could be an error or a malicious attempt to run code on your computer. -Requested path was: {f} -""", file=sys.stderr) - return - - if not shared.cmd_opts.hide_ui_dir_config: - path = os.path.normpath(f) - if platform.system() == "Windows": - os.startfile(path) - elif platform.system() == "Darwin": - sp.Popen(["open", path]) - elif "microsoft-standard-WSL2" in platform.uname().release: - sp.Popen(["wsl-open", path]) - else: - sp.Popen(["xdg-open", path]) - - with gr.Column(variant='panel'): - with gr.Group(): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) - - generation_info = None - with gr.Column(): - with gr.Row(elem_id=f"image_buttons_{tabname}"): - open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') - - if tabname != "extras": - save = gr.Button('Save', elem_id=f'save_{tabname}') - save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') - - buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) - - open_folder_button.click( - fn=lambda: open_folder(opts.outdir_samples or outdir), - inputs=[], - outputs=[], - ) - - if tabname != "extras": - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') - - with gr.Group(): - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') - if tabname == 'txt2img' or tabname == 'img2img': - generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") - generation_info_button.click( - fn=update_generation_info, - _js="(x, y) => [x, y, selected_gallery_index()]", - 
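# note: the _js arrow function above runs in the browser and rewrites the
# inputs to [generation_info, html_info, selected_gallery_index()], so
# update_generation_info receives the currently selected gallery image as its
# third value; preprocess=False hands the values over without Gradio's usual
# component preprocessing.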
inputs=[generation_info, html_info], - outputs=[html_info], - preprocess=False - ) - - save.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - save_zip.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - else: - html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) - return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log - - -def create_sampler_and_steps_selection(choices, tabname): - if opts.samplers_in_dropdown: - with FormRow(elem_id=f"sampler_selection_{tabname}"): - sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - else: - with FormGroup(elem_id=f"sampler_selection_{tabname}"): - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") - - return steps, sampler_index - - -def ordered_ui_categories(): - user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))} - - for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)): - yield category - - -def create_ui(): - import modules.img2img - import modules.txt2img - - reload_javascript() - - parameters_copypaste.reset() - - modules.scripts.scripts_current = modules.scripts.scripts_txt2img - modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) - - with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) - - dummy_component = gr.Label(visible=False) - txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Row(elem_id='txt2img_progress_row'): - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="txt2img_progressbar") - txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) - setup_progressbar(progressbar, txt2img_preview, 'txt2img') - - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel', elem_id="txt2img_settings"): - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") - - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="txt2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") - height = gr.Slider(minimum=64, 
maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") - - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - - elif category == "cfg": - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") - - elif category == "seed": - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') - - elif category == "checkboxes": - with FormRow(elem_id="txt2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") - enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") - hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) - - elif category == "hires_fix": - with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: - with FormRow(elem_id="txt2img_hires_fix_row1"): - hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) - hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") - - with FormRow(elem_id="txt2img_hires_fix_row2"): - hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") - hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") - hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") - - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - - elif category == "scripts": - with FormGroup(elem_id="txt2img_script_container"): - custom_inputs = modules.scripts.scripts_txt2img.setup_ui() - - hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] - for input in hr_resolution_preview_inputs: - input.change( - fn=calc_resolution_hires, - inputs=hr_resolution_preview_inputs, - outputs=[hr_final_resolution], - show_progress=False, - ) - input.change( - None, - _js="onCalcResolutionHires", - inputs=hr_resolution_preview_inputs, - outputs=[], - show_progress=False, - ) - - txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) - parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, 
dummy_component, is_subseed=True) - - txt2img_args = dict( - fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), - _js="submit", - inputs=[ - txt2img_prompt, - txt2img_negative_prompt, - txt2img_prompt_style, - txt2img_prompt_style2, - steps, - sampler_index, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - enable_hr, - denoising_strength, - hr_scale, - hr_upscaler, - hr_second_pass_steps, - hr_resize_x, - hr_resize_y, - ] + custom_inputs, - - outputs=[ - txt2img_gallery, - generation_info, - html_info, - html_log, - ], - show_progress=False, - ) - - txt2img_prompt.submit(**txt2img_args) - submit.click(**txt2img_args) - - txt_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - txt_prompt_img - ], - outputs=[ - txt2img_prompt, - txt_prompt_img - ] - ) - - enable_hr.change( - fn=lambda x: gr_show(x), - inputs=[enable_hr], - outputs=[hr_options], - show_progress = False, - ) - - txt2img_paste_fields = [ - (txt2img_prompt, "Prompt"), - (txt2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - (enable_hr, lambda d: "Denoising strength" in d), - (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), - (hr_scale, "Hires upscale"), - (hr_upscaler, "Hires upscaler"), - (hr_second_pass_steps, "Hires steps"), - (hr_resize_x, "Hires resize-1"), - (hr_resize_y, "Hires resize-2"), - *modules.scripts.scripts_txt2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) - - txt2img_preview_params = [ - txt2img_prompt, - txt2img_negative_prompt, - steps, - sampler_index, - cfg_scale, - seed, - width, - height, - ] - - token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) - - modules.scripts.scripts_current = modules.scripts.scripts_img2img - modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) - - with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) - - with gr.Row(elem_id='img2img_progress_row'): - img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="img2img_progressbar") - img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) - setup_progressbar(progressbar, img2img_preview, 'img2img') - - with FormRow().style(equal_height=False): - with gr.Column(variant='panel', elem_id="img2img_settings"): - - with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: - with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", 
interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) - - with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) - init_img_with_mask_orig = gr.State(None) - - use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch" - if use_color_sketch: - def update_orig(image, state): - if image is not None: - same_size = state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state - - init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) - - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") - - with FormRow(): - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") - mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") - - with FormRow(): - mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") - inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - - with FormRow(): - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") - - with FormRow(): - with gr.Column(): - inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") - - with gr.Column(scale=4): - inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - - with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): - hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' - gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>
") - img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") - img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") - - with FormRow(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") - - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") - - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - elif category == "cfg": - with FormGroup(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") - - elif category == "seed": - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') - - elif category == "checkboxes": - with FormRow(elem_id="img2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") - - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - elif category == "scripts": - with FormGroup(elem_id="img2img_script_container"): - custom_inputs = modules.scripts.scripts_img2img.setup_ui() - - img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) - parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - img2img_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - img2img_prompt_img - ], - outputs=[ - img2img_prompt, - img2img_prompt_img - ] - ) - - mask_mode.change( - lambda mode, img: { - init_img_with_mask: gr_show(mode == 0), - init_img_inpaint: gr_show(mode == 1), - init_mask_inpaint: gr_show(mode == 1), - }, - inputs=[mask_mode, init_img_with_mask], - outputs=[ - init_img_with_mask, - init_img_inpaint, - init_mask_inpaint, - ], - ) - - img2img_args = dict( - fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), - _js="submit_img2img", - inputs=[ - 
dummy_component, - img2img_prompt, - img2img_negative_prompt, - img2img_prompt_style, - img2img_prompt_style2, - init_img, - init_img_with_mask, - init_img_with_mask_orig, - init_img_inpaint, - init_mask_inpaint, - mask_mode, - steps, - sampler_index, - mask_blur, - mask_alpha, - inpainting_fill, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - denoising_strength, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - resize_mode, - inpaint_full_res, - inpaint_full_res_padding, - inpainting_mask_invert, - img2img_batch_input_dir, - img2img_batch_output_dir, - ] + custom_inputs, - outputs=[ - img2img_gallery, - generation_info, - html_info, - html_log, - ], - show_progress=False, - ) - - img2img_prompt.submit(**img2img_args) - submit.click(**img2img_args) - - img2img_interrogate.click( - fn=interrogate, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - img2img_deepbooru.click( - fn=interrogate_deepbooru, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] - style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] - style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] - - for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): - button.click( - fn=add_style, - _js="ask_for_style_name", - # Have to pass empty dummy component here, because the JavaScript and Python function have to accept - # the same number of parameters, but we only know the style-name after the JavaScript prompt - inputs=[dummy_component, prompt, negative_prompt], - outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], - ) - - for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): - button.click( - fn=apply_styles, - _js=js_func, - inputs=[prompt, negative_prompt, style1, style2], - outputs=[prompt, negative_prompt, style1, style2], - ) - - token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) - - img2img_paste_fields = [ - (img2img_prompt, "Prompt"), - (img2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - (mask_blur, "Mask blur"), - *modules.scripts.scripts_img2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) - parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) - - modules.scripts.scripts_current = None - - with gr.Blocks(analytics_enabled=False) as extras_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - with gr.Tabs(elem_id="mode_extras"): - with gr.TabItem('Single Image', elem_id="extras_single_tab"): - extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") - - with gr.TabItem('Batch Process', 
elem_id="extras_batch_process_tab"): - image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") - - with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): - extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") - extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") - show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") - - submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - - with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") - with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): - with gr.Group(): - with gr.Row(): - upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") - upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") - - with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - - with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") - - with gr.Group(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") - - with gr.Group(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") - - with gr.Group(): - upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") - - result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) - - submit.click( - fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']), - _js="get_extras_tab_index", - inputs=[ - dummy_component, - dummy_component, - extras_image, - image_batch, - extras_batch_input_dir, - extras_batch_output_dir, - show_extras_results, - gfpgan_visibility, - codeformer_visibility, - codeformer_weight, - upscaling_resize, - upscaling_resize_w, - upscaling_resize_h, - upscaling_crop, - extras_upscaler_1, - extras_upscaler_2, - extras_upscaler_2_visibility, - upscale_before_face_fix, - ], - outputs=[ - result_images, - html_info_x, 
- html_info, - ] - ) - parameters_copypaste.add_paste_fields("extras", extras_image, None) - - extras_image.change( - fn=modules.extras.clear_cache, - inputs=[], outputs=[] - ) - - with gr.Blocks(analytics_enabled=False) as pnginfo_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") - - with gr.Column(variant='panel'): - html = gr.HTML() - generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info") - html2 = gr.HTML() - with gr.Row(): - buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) - parameters_copypaste.bind_buttons(buttons, image, generation_info) - - image.change( - fn=wrap_gradio_call(modules.extras.run_pnginfo), - inputs=[image], - outputs=[html, generation_info, html2], - ) - - with gr.Blocks(analytics_enabled=False) as modelmerger_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - gr.HTML(value="
<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>
") - - with gr.Row(): - primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") - create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") - - secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") - create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B") - - tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") - create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") - - custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") - interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") - interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") - - with gr.Row(): - checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") - save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") - - modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') - - with gr.Column(variant='panel'): - submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) - - with gr.Blocks(analytics_enabled=False) as train_interface: - with gr.Row().style(equal_height=False): - gr.HTML(value="
<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>
") - - with gr.Row().style(equal_height=False): - with gr.Tabs(elem_id="train_tabs"): - - with gr.Tab(label="Create embedding"): - new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") - initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") - nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") - overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") - - with gr.Tab(label="Create hypernetwork"): - new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") - new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") - new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") - new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") - new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") - new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") - new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") - new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. 
ex:'0, 0.01, 0'") - overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") - - with gr.Tab(label="Preprocess images"): - process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") - process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") - process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") - process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") - - with gr.Row(): - process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") - process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") - process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") - process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") - - with gr.Row(visible=False) as process_split_extra_row: - process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") - process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") - - with gr.Row(visible=False) as process_focal_crop_row: - process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") - process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") - process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - with gr.Row(): - interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") - run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") - - process_split.change( - fn=lambda show: gr_show(show), - inputs=[process_split], - outputs=[process_split_extra_row], - ) - - process_focal_crop.change( - fn=lambda show: gr_show(show), - inputs=[process_focal_crop], - outputs=[process_focal_crop_row], - ) - - def get_textual_inversion_template_names(): - return sorted([x for x in textual_inversion.textual_inversion_templates]) - - with gr.Tab(label="Train"): - gr.HTML(value="
<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>
") - with FormRow(): - train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) - create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") - - train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) - create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") - - with FormRow(): - embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") - hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") - - with FormRow(): - clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) - clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) - - with FormRow(): - batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") - gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") - - dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") - log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") - - with FormRow(): - template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) - create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") - - training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") - training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") - varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") - steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") - - with FormRow(): - create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") - save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") - - save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") - preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") - - shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") - tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") - - latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") - - with gr.Row(): - train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") - interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") - train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") - - params = script_callbacks.UiTrainTabParams(txt2img_preview_params) - - script_callbacks.ui_train_tabs_callback(params) - - with gr.Column(): - progressbar = gr.HTML(elem_id="ti_progressbar") - ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) - - ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) - ti_preview = gr.Image(elem_id='ti_preview', visible=False) - ti_progress = gr.HTML(elem_id="ti_progress", value="") - ti_outcome = gr.HTML(elem_id="ti_error", value="") - setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) - - create_embedding.click( - fn=modules.textual_inversion.ui.create_embedding, - inputs=[ - new_embedding_name, - initialization_text, - nvpt, - overwrite_old_embedding, - ], - outputs=[ - train_embedding_name, - ti_output, - ti_outcome, - ] - ) - - create_hypernetwork.click( - fn=modules.hypernetworks.ui.create_hypernetwork, - inputs=[ - new_hypernetwork_name, - new_hypernetwork_sizes, - overwrite_old_hypernetwork, - new_hypernetwork_layer_structure, - new_hypernetwork_activation_func, - new_hypernetwork_initialization_option, - new_hypernetwork_add_layer_norm, - new_hypernetwork_use_dropout, - new_hypernetwork_dropout_structure - ], - outputs=[ - train_hypernetwork_name, - ti_output, - ti_outcome, - ] - ) - - run_preprocess.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - process_src, - process_dst, - process_width, - process_height, - preprocess_txt_action, - process_flip, - process_split, - process_caption, - process_caption_deepbooru, - process_split_threshold, - process_overlap_ratio, - process_focal_crop, - process_focal_crop_face_weight, - process_focal_crop_entropy_weight, - process_focal_crop_edges_weight, - process_focal_crop_debug, - ], - outputs=[ - ti_output, - ti_outcome, - ], - ) - - train_embedding.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_embedding_name, - embedding_learn_rate, - batch_size, - gradient_step, - dataset_directory, - log_directory, - training_width, - training_height, - varsize, - steps, - clip_grad_mode, - clip_grad_value, - shuffle_tags, - tag_drop_out, - latent_sampling_method, - create_image_every, - save_embedding_every, - template_file, - save_image_with_stored_embedding, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - train_hypernetwork.click( - fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, 
extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_hypernetwork_name, - hypernetwork_learn_rate, - batch_size, - gradient_step, - dataset_directory, - log_directory, - training_width, - training_height, - varsize, - steps, - clip_grad_mode, - clip_grad_value, - shuffle_tags, - tag_drop_out, - latent_sampling_method, - create_image_every, - save_embedding_every, - template_file, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - interrupt_training.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - interrupt_preprocessing.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - def create_setting_component(key, is_quicksettings=False): - def fun(): - return opts.data[key] if key in opts.data else opts.data_labels[key].default - - info = opts.data_labels[key] - t = type(info.default) - - args = info.component_args() if callable(info.component_args) else info.component_args - - if info.component is not None: - comp = info.component - elif t == str: - comp = gr.Textbox - elif t == int: - comp = gr.Number - elif t == bool: - comp = gr.Checkbox - else: - raise Exception(f'bad options item type: {str(t)} for key {key}') - - elem_id = "setting_"+key - - if info.refresh is not None: - if is_quicksettings: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - with FormRow(): - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - - return res - - components = [] - component_dict = {} - - script_callbacks.ui_settings_callback() - opts.reorder() - - def run_settings(*args): - changed = [] - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - if comp == dummy_component: - continue - - if opts.set(key, value): - changed.append(key) - - try: - opts.save(shared.config_filename) - except RuntimeError: - return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' - return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' 
- - def run_settings_single(value, key): - if not opts.same_type(value, opts.data_labels[key].default): - return gr.update(visible=True), opts.dumpjson() - - if not opts.set(key, value): - return gr.update(value=getattr(opts, key)), opts.dumpjson() - - opts.save(shared.config_filename) - - return gr.update(value=value), opts.dumpjson() - - with gr.Blocks(analytics_enabled=False) as settings_interface: - with gr.Row(): - with gr.Column(scale=6): - settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") - with gr.Column(): - restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") - - result = gr.HTML(elem_id="settings_result") - - quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] - quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} - - quicksettings_list = [] - - previous_section = None - current_tab = None - with gr.Tabs(elem_id="settings"): - for i, (k, item) in enumerate(opts.data_labels.items()): - section_must_be_skipped = item.section[0] is None - - if previous_section != item.section and not section_must_be_skipped: - elem_id, text = item.section - - if current_tab is not None: - current_tab.__exit__() - - current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) - current_tab.__enter__() - - previous_section = item.section - - if k in quicksettings_names and not shared.cmd_opts.freeze_settings: - quicksettings_list.append((i, k, item)) - components.append(dummy_component) - elif section_must_be_skipped: - components.append(dummy_component) - else: - component = create_setting_component(k) - component_dict[k] = component - components.append(component) - - if current_tab is not None: - current_tab.__exit__() - - with gr.TabItem("Actions"): - request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") - download_localization = gr.Button(value='Download localization template', elem_id="download_localization") - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") - - if os.path.exists("html/licenses.html"): - with open("html/licenses.html", encoding="utf8") as file: - with gr.TabItem("Licenses"): - gr.HTML(file.read(), elem_id="licenses") - - gr.Button(value="Show all pages", elem_id="settings_show_all_pages") - - request_notifications.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='function(){}' - ) - - download_localization.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='download_localization' - ) - - def reload_scripts(): - modules.scripts.reload_script_body_only() - reload_javascript() # need to refresh the html page - - reload_script_bodies.click( - fn=reload_scripts, - inputs=[], - outputs=[] - ) - - def request_restart(): - shared.state.interrupt() - shared.state.need_restart = True - - restart_gradio.click( - fn=request_restart, - _js='restart_reload', - inputs=[], - outputs=[], - ) - - interfaces = [ - (txt2img_interface, "txt2img", "txt2img"), - (img2img_interface, "img2img", "img2img"), - (extras_interface, "Extras", "extras"), - (pnginfo_interface, "PNG Info", "pnginfo"), - (modelmerger_interface, "Checkpoint Merger", "modelmerger"), - (train_interface, "Train", "ti"), - ] - - css = "" - - for cssfile in modules.scripts.list_files_with_name("style.css"): - if not os.path.isfile(cssfile): - continue - - with open(cssfile, "r", 
encoding="utf8") as file: - css += file.read() + "\n" - - if os.path.exists(os.path.join(script_path, "user.css")): - with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: - css += file.read() + "\n" - - if not cmd_opts.no_progressbar_hiding: - css += css_hide_progressbar - - interfaces += script_callbacks.ui_tabs_callback() - interfaces += [(settings_interface, "Settings", "settings")] - - extensions_interface = ui_extensions.create_ui() - interfaces += [(extensions_interface, "Extensions", "extensions")] - - with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: - with gr.Row(elem_id="quicksettings"): - for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): - component = create_setting_component(k, is_quicksettings=True) - component_dict[k] = component - - parameters_copypaste.integrate_settings_paste_fields(component_dict) - parameters_copypaste.run_bind() - - with gr.Tabs(elem_id="tabs") as tabs: - for interface, label, ifid in interfaces: - with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): - interface.render() - - if os.path.exists(os.path.join(script_path, "notification.mp3")): - audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) - - if os.path.exists("html/footer.html"): - with open("html/footer.html", encoding="utf8") as file: - footer = file.read() - footer = footer.format(versions=versions_html()) - gr.HTML(footer, elem_id="footer") - - text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) - settings_submit.click( - fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), - inputs=components, - outputs=[text_settings, result], - ) - - for i, k, item in quicksettings_list: - component = component_dict[k] - - component.change( - fn=lambda value, k=k: run_settings_single(value, key=k), - inputs=[component], - outputs=[component, text_settings], - ) - - component_keys = [k for k in opts.data_labels.keys() if k in component_dict] - - def get_settings_values(): - return [getattr(opts, key) for key in component_keys] - - demo.load( - fn=get_settings_values, - inputs=[], - outputs=[component_dict[k] for k in component_keys], - ) - - def modelmerger(*args): - try: - results = modules.extras.run_modelmerger(*args) - except Exception as e: - print("Error loading/saving model file:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - modules.sd_models.list_models() # to remove the potentially missing models from the list - return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)] - return results - - modelmerger_merge.click( - fn=modelmerger, - inputs=[ - primary_model_name, - secondary_model_name, - tertiary_model_name, - interp_method, - interp_amount, - save_as_half, - custom_name, - checkpoint_format, - ], - outputs=[ - submit_result, - primary_model_name, - secondary_model_name, - tertiary_model_name, - component_dict['sd_model_checkpoint'], - ] - ) - - ui_config_file = cmd_opts.ui_config_file - ui_settings = {} - settings_count = len(ui_settings) - error_loading = False - - try: - if os.path.exists(ui_config_file): - with open(ui_config_file, "r", encoding="utf8") as file: - ui_settings = json.load(file) - except Exception: - error_loading = True - print("Error loading settings:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - - def 
loadsave(path, x): - def apply_field(obj, field, condition=None, init_field=None): - key = path + "/" + field - - if getattr(obj, 'custom_script_source', None) is not None: - key = 'customscript/' + obj.custom_script_source + '/' + key - - if getattr(obj, 'do_not_save_to_config', False): - return - - saved_value = ui_settings.get(key, None) - if saved_value is None: - ui_settings[key] = getattr(obj, field) - elif condition and not condition(saved_value): - print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') - else: - setattr(obj, field, saved_value) - if init_field is not None: - init_field(saved_value) - - if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible: - apply_field(x, 'visible') - - if type(x) == gr.Slider: - apply_field(x, 'value') - apply_field(x, 'minimum') - apply_field(x, 'maximum') - apply_field(x, 'step') - - if type(x) == gr.Radio: - apply_field(x, 'value', lambda val: val in x.choices) - - if type(x) == gr.Checkbox: - apply_field(x, 'value') - - if type(x) == gr.Textbox: - apply_field(x, 'value') - - if type(x) == gr.Number: - apply_field(x, 'value') - - if type(x) == gr.Dropdown: - apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) - - visit(txt2img_interface, loadsave, "txt2img") - visit(img2img_interface, loadsave, "img2img") - visit(extras_interface, loadsave, "extras") - visit(modelmerger_interface, loadsave, "modelmerger") - visit(train_interface, loadsave, "train") - - if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): - with open(ui_config_file, "w", encoding="utf8") as file: - json.dump(ui_settings, file, indent=4) - - return demo - - -def reload_javascript(): - with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: - javascript = f'' - - scripts_list = modules.scripts.list_scripts("javascript", ".js") - - for basedir, filename, path in scripts_list: - with open(path, "r", encoding="utf8") as jsfile: - javascript += f"\n" - - if cmd_opts.theme is not None: - javascript += f"\n\n" - - javascript += f"\n" - - def template_response(*args, **kwargs): - res = shared.GradioTemplateResponseOriginal(*args, **kwargs) - res.body = res.body.replace( - b'', f'{javascript}'.encode("utf8")) - res.init_headers() - return res - - gradio.routes.templates.TemplateResponse = template_response - - -if not hasattr(shared, 'GradioTemplateResponseOriginal'): - shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse - - -def versions_html(): - import torch - import launch - - python_version = ".".join([str(x) for x in sys.version_info[0:3]]) - commit = launch.commit_hash() - short_commit = commit[0:8] - - if shared.xformers_available: - import xformers - xformers_version = xformers.__version__ - else: - xformers_version = "N/A" - - return f""" -python: {python_version} - •  -torch: {torch.__version__} - •  -xformers: {xformers_version} - •  -gradio: {gr.__version__} - •  -commit: {short_commit} -""" -- cgit v1.2.3 From 54dd5d6634ead25311a8bea0484675607601a21a Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Tue, 10 Jan 2023 11:54:49 +0300 Subject: Split history ui.py to ui_progress.py --- modules/temp | 1928 --------------------------------------------------------- modules/ui.py | 1928 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1928 insertions(+), 1928 deletions(-) delete mode 100644 
modules/temp create mode 100644 modules/ui.py diff --git a/modules/temp b/modules/temp deleted file mode 100644 index 9b9081b5..00000000 --- a/modules/temp +++ /dev/null @@ -1,1928 +0,0 @@ -import html -import json -import math -import mimetypes -import os -import platform -import random -import subprocess as sp -import sys -import tempfile -import time -import traceback -from functools import partial, reduce - -import gradio as gr -import gradio.routes -import gradio.utils -import numpy as np -from PIL import Image, PngImagePlugin -from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call - -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru -from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML -from modules.paths import script_path - -from modules.shared import opts, cmd_opts, restricted_opts - -import modules.codeformer_model -import modules.generation_parameters_copypaste as parameters_copypaste -import modules.gfpgan_model -import modules.hypernetworks.ui -import modules.scripts -import modules.shared as shared -import modules.styles -import modules.textual_inversion.ui -from modules import prompt_parser -from modules.images import save_image -from modules.sd_hijack import model_hijack -from modules.sd_samplers import samplers, samplers_for_img2img -from modules.textual_inversion import textual_inversion -import modules.hypernetworks.ui -from modules.generation_parameters_copypaste import image_from_url_text - -# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI -mimetypes.init() -mimetypes.add_type('application/javascript', '.js') - -if not cmd_opts.share and not cmd_opts.listen: - # fix gradio phoning home - gradio.utils.version_check = lambda: None - gradio.utils.get_local_ip_address = lambda: '127.0.0.1' - -if cmd_opts.ngrok is not None: - import modules.ngrok as ngrok - print('ngrok authtoken detected, trying to connect...') - ngrok.connect( - cmd_opts.ngrok, - cmd_opts.port if cmd_opts.port is not None else 7860, - cmd_opts.ngrok_region - ) - - -def gr_show(visible=True): - return {"visible": visible, "__type__": "update"} - - -sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" -sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None - -css_hide_progressbar = """ -.wrap .m-12 svg { display:none!important; } -.wrap .m-12::before { content:"Loading..." } -.wrap .z-20 svg { display:none!important; } -.wrap .z-20::before { content:"Loading..." } -.progress-bar { display:none!important; } -.meta-text { display:none!important; } -.meta-text-center { display:none!important; } -""" - -# Using constants for these since the variation selector isn't visible. -# Important that they exactly match script.js for tooltip to work. -random_symbol = '\U0001f3b2\ufe0f' # 🎲️ -reuse_symbol = '\u267b\ufe0f' # ♻️ -paste_symbol = '\u2199\ufe0f' # ↙ -folder_symbol = '\U0001f4c2' # 📂 -refresh_symbol = '\U0001f504' # 🔄 -save_style_symbol = '\U0001f4be' # 💾 -apply_style_symbol = '\U0001f4cb' # 📋 -clear_prompt_symbol = '\U0001F5D1' # 🗑️ - - -def plaintext_to_html(text): - text = "
<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>
" - return text - -def send_gradio_gallery_to_image(x): - if len(x) == 0: - return None - return image_from_url_text(x[0]) - -def save_files(js_data, images, do_make_zip, index): - import csv - filenames = [] - fullfns = [] - - #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it - class MyObject: - def __init__(self, d=None): - if d is not None: - for key, value in d.items(): - setattr(self, key, value) - - data = json.loads(js_data) - - p = MyObject(data) - path = opts.outdir_save - save_to_dirs = opts.use_save_to_dirs_for_ui - extension: str = opts.samples_format - start_index = 0 - - if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only - - images = [images[index]] - start_index = index - - os.makedirs(opts.outdir_save, exist_ok=True) - - with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: - at_start = file.tell() == 0 - writer = csv.writer(file) - if at_start: - writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) - - for image_index, filedata in enumerate(images, start_index): - image = image_from_url_text(filedata) - - is_grid = image_index < p.index_of_first_image - i = 0 if is_grid else (image_index - p.index_of_first_image) - - fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) - - filename = os.path.relpath(fullfn, path) - filenames.append(filename) - fullfns.append(fullfn) - if txt_fullfn: - filenames.append(os.path.basename(txt_fullfn)) - fullfns.append(txt_fullfn) - - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) - - # Make Zip - if do_make_zip: - zip_filepath = os.path.join(path, "images.zip") - - from zipfile import ZipFile - with ZipFile(zip_filepath, "w") as zip_file: - for i in range(len(fullfns)): - with open(fullfns[i], mode="rb") as f: - zip_file.writestr(filenames[i], f.read()) - fullfns.insert(0, zip_filepath) - - return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") - - -def calc_time_left(progress, threshold, label, force_display, show_eta): - if progress == 0: - return "" - else: - time_since_start = time.time() - shared.state.time_start - eta = (time_since_start/progress) - eta_relative = eta-time_since_start - if (eta_relative > threshold and show_eta) or force_display: - if eta_relative > 3600: - return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) - elif eta_relative > 60: - return label + time.strftime('%M:%S', time.gmtime(eta_relative)) - else: - return label + time.strftime('%Ss', time.gmtime(eta_relative)) - else: - return "" - - -def check_progress_call(id_part): - if shared.state.job_count == 0: - return "", gr_show(False), gr_show(False), gr_show(False) - - progress = 0 - - if shared.state.job_count > 0: - progress += shared.state.job_no / shared.state.job_count - if shared.state.sampling_steps > 0: - progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps - - # Show progress percentage and time left at the same moment, and base it also on steps done - show_eta = progress >= 0.01 or shared.state.sampling_step >= 10 - - 
time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta) - if time_left != "": - shared.state.time_left_force_display = True - - progress = min(progress, 1) - - progressbar = "" - if opts.show_progressbar: - progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}
""" - - image = gr_show(False) - preview_visibility = gr_show(False) - - if opts.show_progress_every_n_steps != 0: - shared.state.set_current_image() - image = shared.state.current_image - - if image is None: - image = gr.update(value=None) - else: - preview_visibility = gr_show(True) - - if shared.state.textinfo is not None: - textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) - else: - textinfo_result = gr_show(False) - - return f"
<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>
", preview_visibility, image, textinfo_result - - -def check_progress_call_initial(id_part): - shared.state.job_count = -1 - shared.state.current_latent = None - shared.state.current_image = None - shared.state.textinfo = None - shared.state.time_start = time.time() - shared.state.time_left_force_display = False - - return check_progress_call(id_part) - - -def visit(x, func, path=""): - if hasattr(x, 'children'): - for c in x.children: - visit(c, func, path) - elif x.label is not None: - func(path + "/" + str(x.label), x) - - -def add_style(name: str, prompt: str, negative_prompt: str): - if name is None: - return [gr_show() for x in range(4)] - - style = modules.styles.PromptStyle(name, prompt, negative_prompt) - shared.prompt_styles.styles[style.name] = style - # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we - # reserialize all styles every time we save them - shared.prompt_styles.save_styles(shared.styles_filename) - - return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] - - -def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): - from modules import processing, devices - - if not enable: - return "" - - p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) - - with devices.autocast(): - p.init([""], [0], [0]) - - return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" - - -def apply_styles(prompt, prompt_neg, style1_name, style2_name): - prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) - prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) - - return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] - - -def interrogate(image): - prompt = shared.interrogator.interrogate(image.convert("RGB")) - - return gr_show(True) if prompt is None else prompt - - -def interrogate_deepbooru(image): - prompt = deepbooru.model.tag(image) - return gr_show(True) if prompt is None else prompt - - -def create_seed_inputs(target_interface): - with FormRow(elem_id=target_interface + '_seed_row'): - seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') - seed.style(container=False) - random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') - reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') - - with gr.Group(elem_id=target_interface + '_subseed_show_box'): - seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) - - # Components to show/hide based on the 'Extra' checkbox - seed_extras = [] - - with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1: - seed_extras.append(seed_extra_row_1) - subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') - subseed.style(container=False) - random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') - reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') - subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + 
'_subseed_strength') - - with FormRow(visible=False) as seed_extra_row_2: - seed_extras.append(seed_extra_row_2) - seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w') - seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h') - - random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) - random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) - - def change_visibility(show): - return {comp: gr_show(show) for comp in seed_extras} - - seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras) - - return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox - - - -def connect_clear_prompt(button): - """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" - button.click( - _js="clear_prompt", - fn=None, - inputs=[], - outputs=[], - ) - - -def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): - """ Connects a 'reuse (sub)seed' button's click event so that it copies last used - (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength - was 0, i.e. no variation seed was used, it copies the normal seed value instead.""" - def copy_seed(gen_info_string: str, index): - res = -1 - - try: - gen_info = json.loads(gen_info_string) - index -= gen_info.get('index_of_first_image', 0) - - if is_subseed and gen_info.get('subseed_strength', 0) > 0: - all_subseeds = gen_info.get('all_subseeds', [-1]) - res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] - else: - all_seeds = gen_info.get('all_seeds', [-1]) - res = all_seeds[index if 0 <= index < len(all_seeds) else 0] - - except json.decoder.JSONDecodeError as e: - if gen_info_string != '': - print("Error parsing JSON generation info:", file=sys.stderr) - print(gen_info_string, file=sys.stderr) - - return [res, gr_show(False)] - - reuse_seed.click( - fn=copy_seed, - _js="(x, y) => [x, selected_gallery_index()]", - show_progress=False, - inputs=[generation_info, dummy_component], - outputs=[seed, dummy_component] - ) - - -def update_token_counter(text, steps): - try: - _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) - prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) - - except Exception: - # a parsing error can happen here during typing, and we don't want to bother the user with - # messages related to it in console - prompt_schedules = [[[steps, text]]] - - flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) - prompts = [prompt_text for step, prompt_text in flat_prompts] - token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) - style_class = ' class="red"' if (token_count > max_length) else "" - return f"{token_count}/{max_length}" - - -def create_toprow(is_img2img): - id_part = "img2img" if is_img2img else "txt2img" - - with gr.Row(elem_id="toprow"): - with gr.Column(scale=6): - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, - placeholder="Prompt (press Ctrl+Enter or 
Alt+Enter to generate)" - ) - - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, - placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" - ) - - with gr.Column(scale=1, elem_id="roll_col"): - paste = gr.Button(value=paste_symbol, elem_id="paste") - save_style = gr.Button(value=save_style_symbol, elem_id="style_create") - prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") - token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") - token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - - clear_prompt_button.click( - fn=lambda *x: x, - _js="confirm_clear_prompt", - inputs=[prompt, negative_prompt], - outputs=[prompt, negative_prompt], - ) - - button_interrogate = None - button_deepbooru = None - if is_img2img: - with gr.Column(scale=1, elem_id="interrogate_col"): - button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") - - with gr.Column(scale=1): - with gr.Row(): - skip = gr.Button('Skip', elem_id=f"{id_part}_skip") - interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") - submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') - - skip.click( - fn=lambda: shared.state.skip(), - inputs=[], - outputs=[], - ) - - interrupt.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - with gr.Row(): - with gr.Column(scale=1, elem_id="style_pos_col"): - prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - - with gr.Column(scale=1, elem_id="style_neg_col"): - prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - - return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button - - -def setup_progressbar(progressbar, preview, id_part, textinfo=None): - if textinfo is None: - textinfo = gr.HTML(visible=False) - - check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) - check_progress.click( - fn=lambda: check_progress_call(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) - check_progress_initial.click( - fn=lambda: check_progress_call_initial(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - -def apply_setting(key, value): - if value is None: - return gr.update() - - if shared.cmd_opts.freeze_settings: - return gr.update() - - # dont allow model to be swapped when model hash exists in prompt - if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: - return gr.update() - - if key == "sd_model_checkpoint": - ckpt_info = sd_models.get_closet_checkpoint_match(value) - - if ckpt_info is not None: - value = ckpt_info.title - else: - return gr.update() - - comp_args = 
opts.data_labels[key].component_args - if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: - return - - valtype = type(opts.data_labels[key].default) - oldval = opts.data.get(key, None) - opts.data[key] = valtype(value) if valtype != type(None) else value - if oldval != value and opts.data_labels[key].onchange is not None: - opts.data_labels[key].onchange() - - opts.save(shared.config_filename) - return value - - -def update_generation_info(args): - generation_info, html_info, img_index = args - try: - generation_info = json.loads(generation_info) - if img_index < 0 or img_index >= len(generation_info["infotexts"]): - return html_info - return plaintext_to_html(generation_info["infotexts"][img_index]) - except Exception: - pass - # if the json parse or anything else fails, just return the old html_info - return html_info - - -def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): - def refresh(): - refresh_method() - args = refreshed_args() if callable(refreshed_args) else refreshed_args - - for k, v in args.items(): - setattr(refresh_component, k, v) - - return gr.update(**(args or {})) - - refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) - refresh_button.click( - fn=refresh, - inputs=[], - outputs=[refresh_component] - ) - return refresh_button - - -def create_output_panel(tabname, outdir): - def open_folder(f): - if not os.path.exists(f): - print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') - return - elif not os.path.isdir(f): - print(f""" -WARNING -An open_folder request was made with an argument that is not a folder. -This could be an error or a malicious attempt to run code on your computer. -Requested path was: {f} -""", file=sys.stderr) - return - - if not shared.cmd_opts.hide_ui_dir_config: - path = os.path.normpath(f) - if platform.system() == "Windows": - os.startfile(path) - elif platform.system() == "Darwin": - sp.Popen(["open", path]) - elif "microsoft-standard-WSL2" in platform.uname().release: - sp.Popen(["wsl-open", path]) - else: - sp.Popen(["xdg-open", path]) - - with gr.Column(variant='panel'): - with gr.Group(): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) - - generation_info = None - with gr.Column(): - with gr.Row(elem_id=f"image_buttons_{tabname}"): - open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') - - if tabname != "extras": - save = gr.Button('Save', elem_id=f'save_{tabname}') - save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') - - buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) - - open_folder_button.click( - fn=lambda: open_folder(opts.outdir_samples or outdir), - inputs=[], - outputs=[], - ) - - if tabname != "extras": - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') - - with gr.Group(): - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') - if tabname == 'txt2img' or tabname == 'img2img': - generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") - generation_info_button.click( - fn=update_generation_info, - _js="(x, y) => [x, y, selected_gallery_index()]", - 
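The create_refresh_button factory defined above packages a common Gradio idiom: a tool button whose click handler re-runs a loader and then pushes freshly computed properties into an existing component via gr.update. A simplified sketch of the same idiom, where attach_refresh and load_choices are hypothetical stand-ins:

    import gradio as gr

    def attach_refresh(dropdown, load_choices):
        def refresh():
            # Re-scan the underlying source, then update the dropdown in place.
            return gr.update(choices=load_choices())

        button = gr.Button("\U0001f504")  # the same 🔄 symbol used above
        button.click(fn=refresh, inputs=[], outputs=[dropdown])
        return button

The original goes one step further: it also setattr's the refreshed values onto the component object itself, so later server-side reads of the component see the new choices too.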
inputs=[generation_info, html_info], - outputs=[html_info], - preprocess=False - ) - - save.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - save_zip.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - else: - html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) - return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log - - -def create_sampler_and_steps_selection(choices, tabname): - if opts.samplers_in_dropdown: - with FormRow(elem_id=f"sampler_selection_{tabname}"): - sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - else: - with FormGroup(elem_id=f"sampler_selection_{tabname}"): - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") - - return steps, sampler_index - - -def ordered_ui_categories(): - user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))} - - for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)): - yield category - - -def create_ui(): - import modules.img2img - import modules.txt2img - - reload_javascript() - - parameters_copypaste.reset() - - modules.scripts.scripts_current = modules.scripts.scripts_txt2img - modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) - - with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) - - dummy_component = gr.Label(visible=False) - txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Row(elem_id='txt2img_progress_row'): - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="txt2img_progressbar") - txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) - setup_progressbar(progressbar, txt2img_preview, 'txt2img') - - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel', elem_id="txt2img_settings"): - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") - - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="txt2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") - height = gr.Slider(minimum=64, 
maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") - - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - - elif category == "cfg": - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") - - elif category == "seed": - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') - - elif category == "checkboxes": - with FormRow(elem_id="txt2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") - enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") - hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) - - elif category == "hires_fix": - with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: - with FormRow(elem_id="txt2img_hires_fix_row1"): - hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) - hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") - - with FormRow(elem_id="txt2img_hires_fix_row2"): - hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") - hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") - hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") - - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - - elif category == "scripts": - with FormGroup(elem_id="txt2img_script_container"): - custom_inputs = modules.scripts.scripts_txt2img.setup_ui() - - hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] - for input in hr_resolution_preview_inputs: - input.change( - fn=calc_resolution_hires, - inputs=hr_resolution_preview_inputs, - outputs=[hr_final_resolution], - show_progress=False, - ) - input.change( - None, - _js="onCalcResolutionHires", - inputs=hr_resolution_preview_inputs, - outputs=[], - show_progress=False, - ) - - txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) - parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, 
dummy_component, is_subseed=True) - - txt2img_args = dict( - fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), - _js="submit", - inputs=[ - txt2img_prompt, - txt2img_negative_prompt, - txt2img_prompt_style, - txt2img_prompt_style2, - steps, - sampler_index, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - enable_hr, - denoising_strength, - hr_scale, - hr_upscaler, - hr_second_pass_steps, - hr_resize_x, - hr_resize_y, - ] + custom_inputs, - - outputs=[ - txt2img_gallery, - generation_info, - html_info, - html_log, - ], - show_progress=False, - ) - - txt2img_prompt.submit(**txt2img_args) - submit.click(**txt2img_args) - - txt_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - txt_prompt_img - ], - outputs=[ - txt2img_prompt, - txt_prompt_img - ] - ) - - enable_hr.change( - fn=lambda x: gr_show(x), - inputs=[enable_hr], - outputs=[hr_options], - show_progress = False, - ) - - txt2img_paste_fields = [ - (txt2img_prompt, "Prompt"), - (txt2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - (enable_hr, lambda d: "Denoising strength" in d), - (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), - (hr_scale, "Hires upscale"), - (hr_upscaler, "Hires upscaler"), - (hr_second_pass_steps, "Hires steps"), - (hr_resize_x, "Hires resize-1"), - (hr_resize_y, "Hires resize-2"), - *modules.scripts.scripts_txt2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) - - txt2img_preview_params = [ - txt2img_prompt, - txt2img_negative_prompt, - steps, - sampler_index, - cfg_scale, - seed, - width, - height, - ] - - token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) - - modules.scripts.scripts_current = modules.scripts.scripts_img2img - modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) - - with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) - - with gr.Row(elem_id='img2img_progress_row'): - img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="img2img_progressbar") - img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) - setup_progressbar(progressbar, img2img_preview, 'img2img') - - with FormRow().style(equal_height=False): - with gr.Column(variant='panel', elem_id="img2img_settings"): - - with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: - with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", 
interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) - - with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) - init_img_with_mask_orig = gr.State(None) - - use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch" - if use_color_sketch: - def update_orig(image, state): - if image is not None: - same_size = state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state - - init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) - - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") - - with FormRow(): - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") - mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") - - with FormRow(): - mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") - inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - - with FormRow(): - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") - - with FormRow(): - with gr.Column(): - inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") - - with gr.Column(scale=4): - inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - - with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): - hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' - gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>
") - img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") - img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") - - with FormRow(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") - - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") - - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - elif category == "cfg": - with FormGroup(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") - - elif category == "seed": - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') - - elif category == "checkboxes": - with FormRow(elem_id="img2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") - - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - elif category == "scripts": - with FormGroup(elem_id="img2img_script_container"): - custom_inputs = modules.scripts.scripts_img2img.setup_ui() - - img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) - parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - img2img_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - img2img_prompt_img - ], - outputs=[ - img2img_prompt, - img2img_prompt_img - ] - ) - - mask_mode.change( - lambda mode, img: { - init_img_with_mask: gr_show(mode == 0), - init_img_inpaint: gr_show(mode == 1), - init_mask_inpaint: gr_show(mode == 1), - }, - inputs=[mask_mode, init_img_with_mask], - outputs=[ - init_img_with_mask, - init_img_inpaint, - init_mask_inpaint, - ], - ) - - img2img_args = dict( - fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), - _js="submit_img2img", - inputs=[ - 
dummy_component, - img2img_prompt, - img2img_negative_prompt, - img2img_prompt_style, - img2img_prompt_style2, - init_img, - init_img_with_mask, - init_img_with_mask_orig, - init_img_inpaint, - init_mask_inpaint, - mask_mode, - steps, - sampler_index, - mask_blur, - mask_alpha, - inpainting_fill, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - denoising_strength, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - resize_mode, - inpaint_full_res, - inpaint_full_res_padding, - inpainting_mask_invert, - img2img_batch_input_dir, - img2img_batch_output_dir, - ] + custom_inputs, - outputs=[ - img2img_gallery, - generation_info, - html_info, - html_log, - ], - show_progress=False, - ) - - img2img_prompt.submit(**img2img_args) - submit.click(**img2img_args) - - img2img_interrogate.click( - fn=interrogate, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - img2img_deepbooru.click( - fn=interrogate_deepbooru, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] - style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] - style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] - - for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): - button.click( - fn=add_style, - _js="ask_for_style_name", - # Have to pass empty dummy component here, because the JavaScript and Python function have to accept - # the same number of parameters, but we only know the style-name after the JavaScript prompt - inputs=[dummy_component, prompt, negative_prompt], - outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], - ) - - for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): - button.click( - fn=apply_styles, - _js=js_func, - inputs=[prompt, negative_prompt, style1, style2], - outputs=[prompt, negative_prompt, style1, style2], - ) - - token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) - - img2img_paste_fields = [ - (img2img_prompt, "Prompt"), - (img2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - (mask_blur, "Mask blur"), - *modules.scripts.scripts_img2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) - parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) - - modules.scripts.scripts_current = None - - with gr.Blocks(analytics_enabled=False) as extras_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - with gr.Tabs(elem_id="mode_extras"): - with gr.TabItem('Single Image', elem_id="extras_single_tab"): - extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") - - with gr.TabItem('Batch Process', 
elem_id="extras_batch_process_tab"): - image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") - - with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): - extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") - extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") - show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") - - submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - - with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") - with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): - with gr.Group(): - with gr.Row(): - upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") - upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") - - with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - - with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") - - with gr.Group(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") - - with gr.Group(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") - - with gr.Group(): - upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") - - result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) - - submit.click( - fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']), - _js="get_extras_tab_index", - inputs=[ - dummy_component, - dummy_component, - extras_image, - image_batch, - extras_batch_input_dir, - extras_batch_output_dir, - show_extras_results, - gfpgan_visibility, - codeformer_visibility, - codeformer_weight, - upscaling_resize, - upscaling_resize_w, - upscaling_resize_h, - upscaling_crop, - extras_upscaler_1, - extras_upscaler_2, - extras_upscaler_2_visibility, - upscale_before_face_fix, - ], - outputs=[ - result_images, - html_info_x, 
- html_info, - ] - ) - parameters_copypaste.add_paste_fields("extras", extras_image, None) - - extras_image.change( - fn=modules.extras.clear_cache, - inputs=[], outputs=[] - ) - - with gr.Blocks(analytics_enabled=False) as pnginfo_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") - - with gr.Column(variant='panel'): - html = gr.HTML() - generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info") - html2 = gr.HTML() - with gr.Row(): - buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) - parameters_copypaste.bind_buttons(buttons, image, generation_info) - - image.change( - fn=wrap_gradio_call(modules.extras.run_pnginfo), - inputs=[image], - outputs=[html, generation_info, html2], - ) - - with gr.Blocks(analytics_enabled=False) as modelmerger_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - gr.HTML(value="
A merger of the two checkpoints will be generated in your checkpoint directory.
") - - with gr.Row(): - primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") - create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") - - secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") - create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B") - - tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") - create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") - - custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") - interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") - interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") - - with gr.Row(): - checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") - save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") - - modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') - - with gr.Column(variant='panel'): - submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) - - with gr.Blocks(analytics_enabled=False) as train_interface: - with gr.Row().style(equal_height=False): - gr.HTML(value="
See wiki for detailed explanation.
") - - with gr.Row().style(equal_height=False): - with gr.Tabs(elem_id="train_tabs"): - - with gr.Tab(label="Create embedding"): - new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") - initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") - nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") - overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") - - with gr.Tab(label="Create hypernetwork"): - new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") - new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") - new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") - new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") - new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") - new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") - new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") - new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. 
ex:'0, 0.01, 0'") - overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") - - with gr.Tab(label="Preprocess images"): - process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") - process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") - process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") - process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") - - with gr.Row(): - process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") - process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") - process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") - process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") - - with gr.Row(visible=False) as process_split_extra_row: - process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") - process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") - - with gr.Row(visible=False) as process_focal_crop_row: - process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") - process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") - process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - with gr.Row(): - interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") - run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") - - process_split.change( - fn=lambda show: gr_show(show), - inputs=[process_split], - outputs=[process_split_extra_row], - ) - - process_focal_crop.change( - fn=lambda show: gr_show(show), - inputs=[process_focal_crop], - outputs=[process_focal_crop_row], - ) - - def get_textual_inversion_template_names(): - return sorted([x for x in textual_inversion.textual_inversion_templates]) - - with gr.Tab(label="Train"): - gr.HTML(value="
Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]
") - with FormRow(): - train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) - create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") - - train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) - create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") - - with FormRow(): - embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") - hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") - - with FormRow(): - clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) - clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) - - with FormRow(): - batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") - gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") - - dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") - log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") - - with FormRow(): - template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) - create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") - - training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") - training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") - varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") - steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") - - with FormRow(): - create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") - save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") - - save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") - preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") - - shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") - tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") - - latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") - - with gr.Row(): - train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") - interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") - train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") - - params = script_callbacks.UiTrainTabParams(txt2img_preview_params) - - script_callbacks.ui_train_tabs_callback(params) - - with gr.Column(): - progressbar = gr.HTML(elem_id="ti_progressbar") - ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) - - ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) - ti_preview = gr.Image(elem_id='ti_preview', visible=False) - ti_progress = gr.HTML(elem_id="ti_progress", value="") - ti_outcome = gr.HTML(elem_id="ti_error", value="") - setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) - - create_embedding.click( - fn=modules.textual_inversion.ui.create_embedding, - inputs=[ - new_embedding_name, - initialization_text, - nvpt, - overwrite_old_embedding, - ], - outputs=[ - train_embedding_name, - ti_output, - ti_outcome, - ] - ) - - create_hypernetwork.click( - fn=modules.hypernetworks.ui.create_hypernetwork, - inputs=[ - new_hypernetwork_name, - new_hypernetwork_sizes, - overwrite_old_hypernetwork, - new_hypernetwork_layer_structure, - new_hypernetwork_activation_func, - new_hypernetwork_initialization_option, - new_hypernetwork_add_layer_norm, - new_hypernetwork_use_dropout, - new_hypernetwork_dropout_structure - ], - outputs=[ - train_hypernetwork_name, - ti_output, - ti_outcome, - ] - ) - - run_preprocess.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - process_src, - process_dst, - process_width, - process_height, - preprocess_txt_action, - process_flip, - process_split, - process_caption, - process_caption_deepbooru, - process_split_threshold, - process_overlap_ratio, - process_focal_crop, - process_focal_crop_face_weight, - process_focal_crop_entropy_weight, - process_focal_crop_edges_weight, - process_focal_crop_debug, - ], - outputs=[ - ti_output, - ti_outcome, - ], - ) - - train_embedding.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_embedding_name, - embedding_learn_rate, - batch_size, - gradient_step, - dataset_directory, - log_directory, - training_width, - training_height, - varsize, - steps, - clip_grad_mode, - clip_grad_value, - shuffle_tags, - tag_drop_out, - latent_sampling_method, - create_image_every, - save_embedding_every, - template_file, - save_image_with_stored_embedding, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - train_hypernetwork.click( - fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, 
extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_hypernetwork_name, - hypernetwork_learn_rate, - batch_size, - gradient_step, - dataset_directory, - log_directory, - training_width, - training_height, - varsize, - steps, - clip_grad_mode, - clip_grad_value, - shuffle_tags, - tag_drop_out, - latent_sampling_method, - create_image_every, - save_embedding_every, - template_file, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - interrupt_training.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - interrupt_preprocessing.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - def create_setting_component(key, is_quicksettings=False): - def fun(): - return opts.data[key] if key in opts.data else opts.data_labels[key].default - - info = opts.data_labels[key] - t = type(info.default) - - args = info.component_args() if callable(info.component_args) else info.component_args - - if info.component is not None: - comp = info.component - elif t == str: - comp = gr.Textbox - elif t == int: - comp = gr.Number - elif t == bool: - comp = gr.Checkbox - else: - raise Exception(f'bad options item type: {str(t)} for key {key}') - - elem_id = "setting_"+key - - if info.refresh is not None: - if is_quicksettings: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - with FormRow(): - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - - return res - - components = [] - component_dict = {} - - script_callbacks.ui_settings_callback() - opts.reorder() - - def run_settings(*args): - changed = [] - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - if comp == dummy_component: - continue - - if opts.set(key, value): - changed.append(key) - - try: - opts.save(shared.config_filename) - except RuntimeError: - return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' - return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' 
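The run_settings handler above deliberately makes two passes over the incoming values: the first loop only asserts that each value matches the type of that setting's default, and the second loop actually writes and collects the changed keys, so a single bad value aborts the batch before anything is half-applied. A minimal standalone sketch of that validate-then-apply shape (defaults, settings and apply_batch are illustrative names, not the webui's Options API):

    # Two-pass settings update: pass 1 validates everything, pass 2 writes.
    # An assertion failure in pass 1 leaves `settings` completely untouched.
    defaults = {"steps": 20, "cfg_scale": 7.0, "tiling": False}
    settings = dict(defaults)

    def apply_batch(updates):
        for key, value in updates.items():  # pass 1: validation only
            assert key in defaults, f"unknown setting: {key}"
            assert isinstance(value, type(defaults[key])), \
                f"bad value for {key}: {value!r}; expecting {type(defaults[key]).__name__}"
        changed = []
        for key, value in updates.items():  # pass 2: apply and record changes
            if settings[key] != value:
                settings[key] = value
                changed.append(key)
        return changed

    print(apply_batch({"steps": 30, "cfg_scale": 7.5}))  # ['steps', 'cfg_scale']

The real handler is looser (opts.same_type tolerates int/float mixes, and dummy components are skipped), but the two-pass shape is the same.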
- - def run_settings_single(value, key): - if not opts.same_type(value, opts.data_labels[key].default): - return gr.update(visible=True), opts.dumpjson() - - if not opts.set(key, value): - return gr.update(value=getattr(opts, key)), opts.dumpjson() - - opts.save(shared.config_filename) - - return gr.update(value=value), opts.dumpjson() - - with gr.Blocks(analytics_enabled=False) as settings_interface: - with gr.Row(): - with gr.Column(scale=6): - settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") - with gr.Column(): - restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") - - result = gr.HTML(elem_id="settings_result") - - quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] - quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} - - quicksettings_list = [] - - previous_section = None - current_tab = None - with gr.Tabs(elem_id="settings"): - for i, (k, item) in enumerate(opts.data_labels.items()): - section_must_be_skipped = item.section[0] is None - - if previous_section != item.section and not section_must_be_skipped: - elem_id, text = item.section - - if current_tab is not None: - current_tab.__exit__() - - current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) - current_tab.__enter__() - - previous_section = item.section - - if k in quicksettings_names and not shared.cmd_opts.freeze_settings: - quicksettings_list.append((i, k, item)) - components.append(dummy_component) - elif section_must_be_skipped: - components.append(dummy_component) - else: - component = create_setting_component(k) - component_dict[k] = component - components.append(component) - - if current_tab is not None: - current_tab.__exit__() - - with gr.TabItem("Actions"): - request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") - download_localization = gr.Button(value='Download localization template', elem_id="download_localization") - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") - - if os.path.exists("html/licenses.html"): - with open("html/licenses.html", encoding="utf8") as file: - with gr.TabItem("Licenses"): - gr.HTML(file.read(), elem_id="licenses") - - gr.Button(value="Show all pages", elem_id="settings_show_all_pages") - - request_notifications.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='function(){}' - ) - - download_localization.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='download_localization' - ) - - def reload_scripts(): - modules.scripts.reload_script_body_only() - reload_javascript() # need to refresh the html page - - reload_script_bodies.click( - fn=reload_scripts, - inputs=[], - outputs=[] - ) - - def request_restart(): - shared.state.interrupt() - shared.state.need_restart = True - - restart_gradio.click( - fn=request_restart, - _js='restart_reload', - inputs=[], - outputs=[], - ) - - interfaces = [ - (txt2img_interface, "txt2img", "txt2img"), - (img2img_interface, "img2img", "img2img"), - (extras_interface, "Extras", "extras"), - (pnginfo_interface, "PNG Info", "pnginfo"), - (modelmerger_interface, "Checkpoint Merger", "modelmerger"), - (train_interface, "Train", "ti"), - ] - - css = "" - - for cssfile in modules.scripts.list_files_with_name("style.css"): - if not os.path.isfile(cssfile): - continue - - with open(cssfile, "r", 
encoding="utf8") as file: - css += file.read() + "\n" - - if os.path.exists(os.path.join(script_path, "user.css")): - with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: - css += file.read() + "\n" - - if not cmd_opts.no_progressbar_hiding: - css += css_hide_progressbar - - interfaces += script_callbacks.ui_tabs_callback() - interfaces += [(settings_interface, "Settings", "settings")] - - extensions_interface = ui_extensions.create_ui() - interfaces += [(extensions_interface, "Extensions", "extensions")] - - with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: - with gr.Row(elem_id="quicksettings"): - for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): - component = create_setting_component(k, is_quicksettings=True) - component_dict[k] = component - - parameters_copypaste.integrate_settings_paste_fields(component_dict) - parameters_copypaste.run_bind() - - with gr.Tabs(elem_id="tabs") as tabs: - for interface, label, ifid in interfaces: - with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): - interface.render() - - if os.path.exists(os.path.join(script_path, "notification.mp3")): - audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) - - if os.path.exists("html/footer.html"): - with open("html/footer.html", encoding="utf8") as file: - footer = file.read() - footer = footer.format(versions=versions_html()) - gr.HTML(footer, elem_id="footer") - - text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) - settings_submit.click( - fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), - inputs=components, - outputs=[text_settings, result], - ) - - for i, k, item in quicksettings_list: - component = component_dict[k] - - component.change( - fn=lambda value, k=k: run_settings_single(value, key=k), - inputs=[component], - outputs=[component, text_settings], - ) - - component_keys = [k for k in opts.data_labels.keys() if k in component_dict] - - def get_settings_values(): - return [getattr(opts, key) for key in component_keys] - - demo.load( - fn=get_settings_values, - inputs=[], - outputs=[component_dict[k] for k in component_keys], - ) - - def modelmerger(*args): - try: - results = modules.extras.run_modelmerger(*args) - except Exception as e: - print("Error loading/saving model file:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - modules.sd_models.list_models() # to remove the potentially missing models from the list - return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)] - return results - - modelmerger_merge.click( - fn=modelmerger, - inputs=[ - primary_model_name, - secondary_model_name, - tertiary_model_name, - interp_method, - interp_amount, - save_as_half, - custom_name, - checkpoint_format, - ], - outputs=[ - submit_result, - primary_model_name, - secondary_model_name, - tertiary_model_name, - component_dict['sd_model_checkpoint'], - ] - ) - - ui_config_file = cmd_opts.ui_config_file - ui_settings = {} - settings_count = len(ui_settings) - error_loading = False - - try: - if os.path.exists(ui_config_file): - with open(ui_config_file, "r", encoding="utf8") as file: - ui_settings = json.load(file) - except Exception: - error_loading = True - print("Error loading settings:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - - def 
loadsave(path, x): - def apply_field(obj, field, condition=None, init_field=None): - key = path + "/" + field - - if getattr(obj, 'custom_script_source', None) is not None: - key = 'customscript/' + obj.custom_script_source + '/' + key - - if getattr(obj, 'do_not_save_to_config', False): - return - - saved_value = ui_settings.get(key, None) - if saved_value is None: - ui_settings[key] = getattr(obj, field) - elif condition and not condition(saved_value): - print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') - else: - setattr(obj, field, saved_value) - if init_field is not None: - init_field(saved_value) - - if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible: - apply_field(x, 'visible') - - if type(x) == gr.Slider: - apply_field(x, 'value') - apply_field(x, 'minimum') - apply_field(x, 'maximum') - apply_field(x, 'step') - - if type(x) == gr.Radio: - apply_field(x, 'value', lambda val: val in x.choices) - - if type(x) == gr.Checkbox: - apply_field(x, 'value') - - if type(x) == gr.Textbox: - apply_field(x, 'value') - - if type(x) == gr.Number: - apply_field(x, 'value') - - if type(x) == gr.Dropdown: - apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) - - visit(txt2img_interface, loadsave, "txt2img") - visit(img2img_interface, loadsave, "img2img") - visit(extras_interface, loadsave, "extras") - visit(modelmerger_interface, loadsave, "modelmerger") - visit(train_interface, loadsave, "train") - - if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): - with open(ui_config_file, "w", encoding="utf8") as file: - json.dump(ui_settings, file, indent=4) - - return demo - - -def reload_javascript(): - with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: - javascript = f'' - - scripts_list = modules.scripts.list_scripts("javascript", ".js") - - for basedir, filename, path in scripts_list: - with open(path, "r", encoding="utf8") as jsfile: - javascript += f"\n" - - if cmd_opts.theme is not None: - javascript += f"\n\n" - - javascript += f"\n" - - def template_response(*args, **kwargs): - res = shared.GradioTemplateResponseOriginal(*args, **kwargs) - res.body = res.body.replace( - b'', f'{javascript}'.encode("utf8")) - res.init_headers() - return res - - gradio.routes.templates.TemplateResponse = template_response - - -if not hasattr(shared, 'GradioTemplateResponseOriginal'): - shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse - - -def versions_html(): - import torch - import launch - - python_version = ".".join([str(x) for x in sys.version_info[0:3]]) - commit = launch.commit_hash() - short_commit = commit[0:8] - - if shared.xformers_available: - import xformers - xformers_version = xformers.__version__ - else: - xformers_version = "N/A" - - return f""" -python: {python_version} - •  -torch: {torch.__version__} - •  -xformers: {xformers_version} - •  -gradio: {gr.__version__} - •  -commit: {short_commit} -""" diff --git a/modules/ui.py b/modules/ui.py new file mode 100644 index 00000000..9b9081b5 --- /dev/null +++ b/modules/ui.py @@ -0,0 +1,1928 @@ +import html +import json +import math +import mimetypes +import os +import platform +import random +import subprocess as sp +import sys +import tempfile +import time +import traceback +from functools import partial, reduce + +import gradio as gr +import gradio.routes +import gradio.utils +import numpy 
as np +from PIL import Image, PngImagePlugin +from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call + +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru +from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML +from modules.paths import script_path + +from modules.shared import opts, cmd_opts, restricted_opts + +import modules.codeformer_model +import modules.generation_parameters_copypaste as parameters_copypaste +import modules.gfpgan_model +import modules.hypernetworks.ui +import modules.scripts +import modules.shared as shared +import modules.styles +import modules.textual_inversion.ui +from modules import prompt_parser +from modules.images import save_image +from modules.sd_hijack import model_hijack +from modules.sd_samplers import samplers, samplers_for_img2img +from modules.textual_inversion import textual_inversion +import modules.hypernetworks.ui +from modules.generation_parameters_copypaste import image_from_url_text + +# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI +mimetypes.init() +mimetypes.add_type('application/javascript', '.js') + +if not cmd_opts.share and not cmd_opts.listen: + # fix gradio phoning home + gradio.utils.version_check = lambda: None + gradio.utils.get_local_ip_address = lambda: '127.0.0.1' + +if cmd_opts.ngrok is not None: + import modules.ngrok as ngrok + print('ngrok authtoken detected, trying to connect...') + ngrok.connect( + cmd_opts.ngrok, + cmd_opts.port if cmd_opts.port is not None else 7860, + cmd_opts.ngrok_region + ) + + +def gr_show(visible=True): + return {"visible": visible, "__type__": "update"} + + +sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg" +sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None + +css_hide_progressbar = """ +.wrap .m-12 svg { display:none!important; } +.wrap .m-12::before { content:"Loading..." } +.wrap .z-20 svg { display:none!important; } +.wrap .z-20::before { content:"Loading..." } +.progress-bar { display:none!important; } +.meta-text { display:none!important; } +.meta-text-center { display:none!important; } +""" + +# Using constants for these since the variation selector isn't visible. +# Important that they exactly match script.js for tooltip to work. +random_symbol = '\U0001f3b2\ufe0f' # 🎲️ +reuse_symbol = '\u267b\ufe0f' # ♻️ +paste_symbol = '\u2199\ufe0f' # ↙ +folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 +save_style_symbol = '\U0001f4be' # 💾 +apply_style_symbol = '\U0001f4cb' # 📋 +clear_prompt_symbol = '\U0001F5D1' # 🗑️ + + +def plaintext_to_html(text): + text = "
<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>
" + return text + +def send_gradio_gallery_to_image(x): + if len(x) == 0: + return None + return image_from_url_text(x[0]) + +def save_files(js_data, images, do_make_zip, index): + import csv + filenames = [] + fullfns = [] + + #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it + class MyObject: + def __init__(self, d=None): + if d is not None: + for key, value in d.items(): + setattr(self, key, value) + + data = json.loads(js_data) + + p = MyObject(data) + path = opts.outdir_save + save_to_dirs = opts.use_save_to_dirs_for_ui + extension: str = opts.samples_format + start_index = 0 + + if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only + + images = [images[index]] + start_index = index + + os.makedirs(opts.outdir_save, exist_ok=True) + + with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: + at_start = file.tell() == 0 + writer = csv.writer(file) + if at_start: + writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) + + for image_index, filedata in enumerate(images, start_index): + image = image_from_url_text(filedata) + + is_grid = image_index < p.index_of_first_image + i = 0 if is_grid else (image_index - p.index_of_first_image) + + fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) + + filename = os.path.relpath(fullfn, path) + filenames.append(filename) + fullfns.append(fullfn) + if txt_fullfn: + filenames.append(os.path.basename(txt_fullfn)) + fullfns.append(txt_fullfn) + + writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) + + # Make Zip + if do_make_zip: + zip_filepath = os.path.join(path, "images.zip") + + from zipfile import ZipFile + with ZipFile(zip_filepath, "w") as zip_file: + for i in range(len(fullfns)): + with open(fullfns[i], mode="rb") as f: + zip_file.writestr(filenames[i], f.read()) + fullfns.insert(0, zip_filepath) + + return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") + + +def calc_time_left(progress, threshold, label, force_display, show_eta): + if progress == 0: + return "" + else: + time_since_start = time.time() - shared.state.time_start + eta = (time_since_start/progress) + eta_relative = eta-time_since_start + if (eta_relative > threshold and show_eta) or force_display: + if eta_relative > 3600: + return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) + elif eta_relative > 60: + return label + time.strftime('%M:%S', time.gmtime(eta_relative)) + else: + return label + time.strftime('%Ss', time.gmtime(eta_relative)) + else: + return "" + + +def check_progress_call(id_part): + if shared.state.job_count == 0: + return "", gr_show(False), gr_show(False), gr_show(False) + + progress = 0 + + if shared.state.job_count > 0: + progress += shared.state.job_no / shared.state.job_count + if shared.state.sampling_steps > 0: + progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps + + # Show progress percentage and time left at the same moment, and base it also on steps done + show_eta = progress >= 0.01 or shared.state.sampling_step >= 10 + + 
time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta) + if time_left != "": + shared.state.time_left_force_display = True + + progress = min(progress, 1) + + progressbar = "" + if opts.show_progressbar: + progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}
""" + + image = gr_show(False) + preview_visibility = gr_show(False) + + if opts.show_progress_every_n_steps != 0: + shared.state.set_current_image() + image = shared.state.current_image + + if image is None: + image = gr.update(value=None) + else: + preview_visibility = gr_show(True) + + if shared.state.textinfo is not None: + textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) + else: + textinfo_result = gr_show(False) + + return f"
<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>
", preview_visibility, image, textinfo_result + + +def check_progress_call_initial(id_part): + shared.state.job_count = -1 + shared.state.current_latent = None + shared.state.current_image = None + shared.state.textinfo = None + shared.state.time_start = time.time() + shared.state.time_left_force_display = False + + return check_progress_call(id_part) + + +def visit(x, func, path=""): + if hasattr(x, 'children'): + for c in x.children: + visit(c, func, path) + elif x.label is not None: + func(path + "/" + str(x.label), x) + + +def add_style(name: str, prompt: str, negative_prompt: str): + if name is None: + return [gr_show() for x in range(4)] + + style = modules.styles.PromptStyle(name, prompt, negative_prompt) + shared.prompt_styles.styles[style.name] = style + # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we + # reserialize all styles every time we save them + shared.prompt_styles.save_styles(shared.styles_filename) + + return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] + + +def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): + from modules import processing, devices + + if not enable: + return "" + + p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) + + with devices.autocast(): + p.init([""], [0], [0]) + + return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" + + +def apply_styles(prompt, prompt_neg, style1_name, style2_name): + prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) + prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) + + return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] + + +def interrogate(image): + prompt = shared.interrogator.interrogate(image.convert("RGB")) + + return gr_show(True) if prompt is None else prompt + + +def interrogate_deepbooru(image): + prompt = deepbooru.model.tag(image) + return gr_show(True) if prompt is None else prompt + + +def create_seed_inputs(target_interface): + with FormRow(elem_id=target_interface + '_seed_row'): + seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') + seed.style(container=False) + random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') + reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') + + with gr.Group(elem_id=target_interface + '_subseed_show_box'): + seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) + + # Components to show/hide based on the 'Extra' checkbox + seed_extras = [] + + with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1: + seed_extras.append(seed_extra_row_1) + subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') + subseed.style(container=False) + random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') + reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') + subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + 
'_subseed_strength') + + with FormRow(visible=False) as seed_extra_row_2: + seed_extras.append(seed_extra_row_2) + seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w') + seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h') + + random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed]) + random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed]) + + def change_visibility(show): + return {comp: gr_show(show) for comp in seed_extras} + + seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras) + + return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox + + + +def connect_clear_prompt(button): + """Given clear button, prompt, and token_counter objects, setup clear prompt button click event""" + button.click( + _js="clear_prompt", + fn=None, + inputs=[], + outputs=[], + ) + + +def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): + """ Connects a 'reuse (sub)seed' button's click event so that it copies last used + (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength + was 0, i.e. no variation seed was used, it copies the normal seed value instead.""" + def copy_seed(gen_info_string: str, index): + res = -1 + + try: + gen_info = json.loads(gen_info_string) + index -= gen_info.get('index_of_first_image', 0) + + if is_subseed and gen_info.get('subseed_strength', 0) > 0: + all_subseeds = gen_info.get('all_subseeds', [-1]) + res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] + else: + all_seeds = gen_info.get('all_seeds', [-1]) + res = all_seeds[index if 0 <= index < len(all_seeds) else 0] + + except json.decoder.JSONDecodeError as e: + if gen_info_string != '': + print("Error parsing JSON generation info:", file=sys.stderr) + print(gen_info_string, file=sys.stderr) + + return [res, gr_show(False)] + + reuse_seed.click( + fn=copy_seed, + _js="(x, y) => [x, selected_gallery_index()]", + show_progress=False, + inputs=[generation_info, dummy_component], + outputs=[seed, dummy_component] + ) + + +def update_token_counter(text, steps): + try: + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) + prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) + + except Exception: + # a parsing error can happen here during typing, and we don't want to bother the user with + # messages related to it in console + prompt_schedules = [[[steps, text]]] + + flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules) + prompts = [prompt_text for step, prompt_text in flat_prompts] + token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0]) + style_class = ' class="red"' if (token_count > max_length) else "" + return f"{token_count}/{max_length}" + + +def create_toprow(is_img2img): + id_part = "img2img" if is_img2img else "txt2img" + + with gr.Row(elem_id="toprow"): + with gr.Column(scale=6): + with gr.Row(): + with gr.Column(scale=80): + with gr.Row(): + prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, + placeholder="Prompt (press Ctrl+Enter or 
Alt+Enter to generate)" + ) + + with gr.Row(): + with gr.Column(scale=80): + with gr.Row(): + negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, + placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" + ) + + with gr.Column(scale=1, elem_id="roll_col"): + paste = gr.Button(value=paste_symbol, elem_id="paste") + save_style = gr.Button(value=save_style_symbol, elem_id="style_create") + prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") + clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") + token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") + + clear_prompt_button.click( + fn=lambda *x: x, + _js="confirm_clear_prompt", + inputs=[prompt, negative_prompt], + outputs=[prompt, negative_prompt], + ) + + button_interrogate = None + button_deepbooru = None + if is_img2img: + with gr.Column(scale=1, elem_id="interrogate_col"): + button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") + button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + + with gr.Column(scale=1): + with gr.Row(): + skip = gr.Button('Skip', elem_id=f"{id_part}_skip") + interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") + submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') + + skip.click( + fn=lambda: shared.state.skip(), + inputs=[], + outputs=[], + ) + + interrupt.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + with gr.Row(): + with gr.Column(scale=1, elem_id="style_pos_col"): + prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + + with gr.Column(scale=1, elem_id="style_neg_col"): + prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + + return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button + + +def setup_progressbar(progressbar, preview, id_part, textinfo=None): + if textinfo is None: + textinfo = gr.HTML(visible=False) + + check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) + check_progress.click( + fn=lambda: check_progress_call(id_part), + show_progress=False, + inputs=[], + outputs=[progressbar, preview, preview, textinfo], + ) + + check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) + check_progress_initial.click( + fn=lambda: check_progress_call_initial(id_part), + show_progress=False, + inputs=[], + outputs=[progressbar, preview, preview, textinfo], + ) + + +def apply_setting(key, value): + if value is None: + return gr.update() + + if shared.cmd_opts.freeze_settings: + return gr.update() + + # dont allow model to be swapped when model hash exists in prompt + if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: + return gr.update() + + if key == "sd_model_checkpoint": + ckpt_info = sd_models.get_closet_checkpoint_match(value) + + if ckpt_info is not None: + value = ckpt_info.title + else: + return gr.update() + + comp_args = 
opts.data_labels[key].component_args + if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: + return + + valtype = type(opts.data_labels[key].default) + oldval = opts.data.get(key, None) + opts.data[key] = valtype(value) if valtype != type(None) else value + if oldval != value and opts.data_labels[key].onchange is not None: + opts.data_labels[key].onchange() + + opts.save(shared.config_filename) + return value + + +def update_generation_info(args): + generation_info, html_info, img_index = args + try: + generation_info = json.loads(generation_info) + if img_index < 0 or img_index >= len(generation_info["infotexts"]): + return html_info + return plaintext_to_html(generation_info["infotexts"][img_index]) + except Exception: + pass + # if the json parse or anything else fails, just return the old html_info + return html_info + + +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + + refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + return refresh_button + + +def create_output_panel(tabname, outdir): + def open_folder(f): + if not os.path.exists(f): + print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') + return + elif not os.path.isdir(f): + print(f""" +WARNING +An open_folder request was made with an argument that is not a folder. +This could be an error or a malicious attempt to run code on your computer. +Requested path was: {f} +""", file=sys.stderr) + return + + if not shared.cmd_opts.hide_ui_dir_config: + path = os.path.normpath(f) + if platform.system() == "Windows": + os.startfile(path) + elif platform.system() == "Darwin": + sp.Popen(["open", path]) + elif "microsoft-standard-WSL2" in platform.uname().release: + sp.Popen(["wsl-open", path]) + else: + sp.Popen(["xdg-open", path]) + + with gr.Column(variant='panel'): + with gr.Group(): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) + + generation_info = None + with gr.Column(): + with gr.Row(elem_id=f"image_buttons_{tabname}"): + open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') + + if tabname != "extras": + save = gr.Button('Save', elem_id=f'save_{tabname}') + save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') + + buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) + + open_folder_button.click( + fn=lambda: open_folder(opts.outdir_samples or outdir), + inputs=[], + outputs=[], + ) + + if tabname != "extras": + with gr.Row(): + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') + + with gr.Group(): + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') + if tabname == 'txt2img' or tabname == 'img2img': + generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") + generation_info_button.click( + fn=update_generation_info, + _js="(x, y) => [x, y, selected_gallery_index()]", + 
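# the _js shim appends selected_gallery_index() to the two declared inputs; update_generation_info unpacks that triple as (generation info, current html, selected image index)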
inputs=[generation_info, html_info], + outputs=[html_info], + preprocess=False + ) + + save.click( + fn=wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ] + ) + + save_zip.click( + fn=wrap_gradio_call(save_files), + _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", + inputs=[ + generation_info, + result_gallery, + html_info, + html_info, + ], + outputs=[ + download_files, + html_log, + ] + ) + + else: + html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') + + parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) + return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log + + +def create_sampler_and_steps_selection(choices, tabname): + if opts.samplers_in_dropdown: + with FormRow(elem_id=f"sampler_selection_{tabname}"): + sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) + else: + with FormGroup(elem_id=f"sampler_selection_{tabname}"): + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) + sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + + return steps, sampler_index + + +def ordered_ui_categories(): + user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))} + + for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)): + yield category + + +def create_ui(): + import modules.img2img + import modules.txt2img + + reload_javascript() + + parameters_copypaste.reset() + + modules.scripts.scripts_current = modules.scripts.scripts_txt2img + modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) + + with gr.Blocks(analytics_enabled=False) as txt2img_interface: + txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) + + dummy_component = gr.Label(visible=False) + txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) + + with gr.Row(elem_id='txt2img_progress_row'): + with gr.Column(scale=1): + pass + + with gr.Column(scale=1): + progressbar = gr.HTML(elem_id="txt2img_progressbar") + txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) + setup_progressbar(progressbar, txt2img_preview, 'txt2img') + + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel', elem_id="txt2img_settings"): + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") + + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="txt2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") + height = gr.Slider(minimum=64, 
maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") + + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "cfg": + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') + + elif category == "checkboxes": + with FormRow(elem_id="txt2img_checkboxes"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") + enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") + hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) + + elif category == "hires_fix": + with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: + with FormRow(elem_id="txt2img_hires_fix_row1"): + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") + + with FormRow(elem_id="txt2img_hires_fix_row2"): + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") + hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") + hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") + + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="txt2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") + + elif category == "scripts": + with FormGroup(elem_id="txt2img_script_container"): + custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + + hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] + for input in hr_resolution_preview_inputs: + input.change( + fn=calc_resolution_hires, + inputs=hr_resolution_preview_inputs, + outputs=[hr_final_resolution], + show_progress=False, + ) + input.change( + None, + _js="onCalcResolutionHires", + inputs=hr_resolution_preview_inputs, + outputs=[], + show_progress=False, + ) + + txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) + parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, 
dummy_component, is_subseed=True) + + txt2img_args = dict( + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), + _js="submit", + inputs=[ + txt2img_prompt, + txt2img_negative_prompt, + txt2img_prompt_style, + txt2img_prompt_style2, + steps, + sampler_index, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + enable_hr, + denoising_strength, + hr_scale, + hr_upscaler, + hr_second_pass_steps, + hr_resize_x, + hr_resize_y, + ] + custom_inputs, + + outputs=[ + txt2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + + txt2img_prompt.submit(**txt2img_args) + submit.click(**txt2img_args) + + txt_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + txt_prompt_img + ], + outputs=[ + txt2img_prompt, + txt_prompt_img + ] + ) + + enable_hr.change( + fn=lambda x: gr_show(x), + inputs=[enable_hr], + outputs=[hr_options], + show_progress = False, + ) + + txt2img_paste_fields = [ + (txt2img_prompt, "Prompt"), + (txt2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (denoising_strength, "Denoising strength"), + (enable_hr, lambda d: "Denoising strength" in d), + (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), + (hr_scale, "Hires upscale"), + (hr_upscaler, "Hires upscaler"), + (hr_second_pass_steps, "Hires steps"), + (hr_resize_x, "Hires resize-1"), + (hr_resize_y, "Hires resize-2"), + *modules.scripts.scripts_txt2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) + + txt2img_preview_params = [ + txt2img_prompt, + txt2img_negative_prompt, + steps, + sampler_index, + cfg_scale, + seed, + width, + height, + ] + + token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) + + modules.scripts.scripts_current = modules.scripts.scripts_img2img + modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) + + with gr.Blocks(analytics_enabled=False) as img2img_interface: + img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) + + with gr.Row(elem_id='img2img_progress_row'): + img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) + + with gr.Column(scale=1): + pass + + with gr.Column(scale=1): + progressbar = gr.HTML(elem_id="img2img_progressbar") + img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) + setup_progressbar(progressbar, img2img_preview, 'img2img') + + with FormRow().style(equal_height=False): + with gr.Column(variant='panel', elem_id="img2img_settings"): + + with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: + with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", 
interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) + + with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) + init_img_with_mask_orig = gr.State(None) + + use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch" + if use_color_sketch: + def update_orig(image, state): + if image is not None: + same_size = state is not None and state.size == image.size + has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + edited = same_size and has_exact_match + return image if not edited or state is None else state + + init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) + + init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") + + with FormRow(): + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") + mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") + + with FormRow(): + mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") + inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") + + with FormRow(): + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") + + with FormRow(): + with gr.Column(): + inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + + with gr.Column(scale=4): + inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") + + with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): + hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
+ gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>
") + img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") + img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + + with FormRow(): + resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + + for category in ordered_ui_categories(): + if category == "sampler": + steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") + + elif category == "dimensions": + with FormRow(): + with gr.Column(elem_id="img2img_column_size", scale=4): + width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") + height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") + + if opts.dimensions_and_batch_together: + with gr.Column(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + elif category == "cfg": + with FormGroup(): + cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") + + elif category == "seed": + seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') + + elif category == "checkboxes": + with FormRow(elem_id="img2img_checkboxes"): + restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") + tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") + + elif category == "batch": + if not opts.dimensions_and_batch_together: + with FormRow(elem_id="img2img_column_batch"): + batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") + batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") + + elif category == "scripts": + with FormGroup(elem_id="img2img_script_container"): + custom_inputs = modules.scripts.scripts_img2img.setup_ui() + + img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) + parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) + + connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) + connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) + + img2img_prompt_img.change( + fn=modules.images.image_data, + inputs=[ + img2img_prompt_img + ], + outputs=[ + img2img_prompt, + img2img_prompt_img + ] + ) + + mask_mode.change( + lambda mode, img: { + init_img_with_mask: gr_show(mode == 0), + init_img_inpaint: gr_show(mode == 1), + init_mask_inpaint: gr_show(mode == 1), + }, + inputs=[mask_mode, init_img_with_mask], + outputs=[ + init_img_with_mask, + init_img_inpaint, + init_mask_inpaint, + ], + ) + + img2img_args = dict( + fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), + _js="submit_img2img", + inputs=[ + 
dummy_component, + img2img_prompt, + img2img_negative_prompt, + img2img_prompt_style, + img2img_prompt_style2, + init_img, + init_img_with_mask, + init_img_with_mask_orig, + init_img_inpaint, + init_mask_inpaint, + mask_mode, + steps, + sampler_index, + mask_blur, + mask_alpha, + inpainting_fill, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + denoising_strength, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + resize_mode, + inpaint_full_res, + inpaint_full_res_padding, + inpainting_mask_invert, + img2img_batch_input_dir, + img2img_batch_output_dir, + ] + custom_inputs, + outputs=[ + img2img_gallery, + generation_info, + html_info, + html_log, + ], + show_progress=False, + ) + + img2img_prompt.submit(**img2img_args) + submit.click(**img2img_args) + + img2img_interrogate.click( + fn=interrogate, + inputs=[init_img], + outputs=[img2img_prompt], + ) + + img2img_deepbooru.click( + fn=interrogate_deepbooru, + inputs=[init_img], + outputs=[img2img_prompt], + ) + + prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] + style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] + style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] + + for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): + button.click( + fn=add_style, + _js="ask_for_style_name", + # Have to pass empty dummy component here, because the JavaScript and Python function have to accept + # the same number of parameters, but we only know the style-name after the JavaScript prompt + inputs=[dummy_component, prompt, negative_prompt], + outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], + ) + + for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): + button.click( + fn=apply_styles, + _js=js_func, + inputs=[prompt, negative_prompt, style1, style2], + outputs=[prompt, negative_prompt, style1, style2], + ) + + token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) + + img2img_paste_fields = [ + (img2img_prompt, "Prompt"), + (img2img_negative_prompt, "Negative prompt"), + (steps, "Steps"), + (sampler_index, "Sampler"), + (restore_faces, "Face restoration"), + (cfg_scale, "CFG scale"), + (seed, "Seed"), + (width, "Size-1"), + (height, "Size-2"), + (batch_size, "Batch size"), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + (denoising_strength, "Denoising strength"), + (mask_blur, "Mask blur"), + *modules.scripts.scripts_img2img.infotext_fields + ] + parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) + parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) + + modules.scripts.scripts_current = None + + with gr.Blocks(analytics_enabled=False) as extras_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + with gr.Tabs(elem_id="mode_extras"): + with gr.TabItem('Single Image', elem_id="extras_single_tab"): + extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") + + with gr.TabItem('Batch Process', 
elem_id="extras_batch_process_tab"): + image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") + + with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): + extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") + extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") + show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") + + submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') + + with gr.Tabs(elem_id="extras_resize_mode"): + with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") + with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): + with gr.Group(): + with gr.Row(): + upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") + upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") + upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") + + with gr.Group(): + extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + + with gr.Group(): + extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") + + with gr.Group(): + gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") + + with gr.Group(): + codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") + codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") + + with gr.Group(): + upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") + + result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) + + submit.click( + fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']), + _js="get_extras_tab_index", + inputs=[ + dummy_component, + dummy_component, + extras_image, + image_batch, + extras_batch_input_dir, + extras_batch_output_dir, + show_extras_results, + gfpgan_visibility, + codeformer_visibility, + codeformer_weight, + upscaling_resize, + upscaling_resize_w, + upscaling_resize_h, + upscaling_crop, + extras_upscaler_1, + extras_upscaler_2, + extras_upscaler_2_visibility, + upscale_before_face_fix, + ], + outputs=[ + result_images, + html_info_x, 
+ html_info, + ] + ) + parameters_copypaste.add_paste_fields("extras", extras_image, None) + + extras_image.change( + fn=modules.extras.clear_cache, + inputs=[], outputs=[] + ) + + with gr.Blocks(analytics_enabled=False) as pnginfo_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") + + with gr.Column(variant='panel'): + html = gr.HTML() + generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info") + html2 = gr.HTML() + with gr.Row(): + buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) + parameters_copypaste.bind_buttons(buttons, image, generation_info) + + image.change( + fn=wrap_gradio_call(modules.extras.run_pnginfo), + inputs=[image], + outputs=[html, generation_info, html2], + ) + + with gr.Blocks(analytics_enabled=False) as modelmerger_interface: + with gr.Row().style(equal_height=False): + with gr.Column(variant='panel'): + gr.HTML(value="
<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your checkpoint directory.</p>
") + + with gr.Row(): + primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") + create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") + + secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") + create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B") + + tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") + create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") + + custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") + interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") + interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") + + with gr.Row(): + checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") + save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") + + modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') + + with gr.Column(variant='panel'): + submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) + + with gr.Blocks(analytics_enabled=False) as train_interface: + with gr.Row().style(equal_height=False): + gr.HTML(value="
<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>
") + + with gr.Row().style(equal_height=False): + with gr.Tabs(elem_id="train_tabs"): + + with gr.Tab(label="Create embedding"): + new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") + initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") + nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") + overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") + + with gr.Tab(label="Create hypernetwork"): + new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") + new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") + new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") + new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") + new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") + new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") + new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") + new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. 
ex:'0, 0.01, 0'") + overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") + + with gr.Tab(label="Preprocess images"): + process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") + process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") + process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") + process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") + preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") + + with gr.Row(): + process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") + process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") + process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") + process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") + process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") + + with gr.Row(visible=False) as process_split_extra_row: + process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") + process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") + + with gr.Row(visible=False) as process_focal_crop_row: + process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") + process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") + process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") + process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + with gr.Row(): + interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") + run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") + + process_split.change( + fn=lambda show: gr_show(show), + inputs=[process_split], + outputs=[process_split_extra_row], + ) + + process_focal_crop.change( + fn=lambda show: gr_show(show), + inputs=[process_focal_crop], + outputs=[process_focal_crop_row], + ) + + def get_textual_inversion_template_names(): + return sorted([x for x in textual_inversion.textual_inversion_templates]) + + with gr.Tab(label="Train"): + gr.HTML(value="
<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>
") + with FormRow(): + train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") + + train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) + create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") + + with FormRow(): + embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") + hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") + + with FormRow(): + clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) + clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) + + with FormRow(): + batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") + gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") + + dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") + log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") + + with FormRow(): + template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) + create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") + + training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") + training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") + varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") + steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") + + with FormRow(): + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") + + save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") + preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") + + shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") + tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") + + latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") + + with gr.Row(): + train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") + interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") + train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") + + params = script_callbacks.UiTrainTabParams(txt2img_preview_params) + + script_callbacks.ui_train_tabs_callback(params) + + with gr.Column(): + progressbar = gr.HTML(elem_id="ti_progressbar") + ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) + + ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) + ti_preview = gr.Image(elem_id='ti_preview', visible=False) + ti_progress = gr.HTML(elem_id="ti_progress", value="") + ti_outcome = gr.HTML(elem_id="ti_error", value="") + setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) + + create_embedding.click( + fn=modules.textual_inversion.ui.create_embedding, + inputs=[ + new_embedding_name, + initialization_text, + nvpt, + overwrite_old_embedding, + ], + outputs=[ + train_embedding_name, + ti_output, + ti_outcome, + ] + ) + + create_hypernetwork.click( + fn=modules.hypernetworks.ui.create_hypernetwork, + inputs=[ + new_hypernetwork_name, + new_hypernetwork_sizes, + overwrite_old_hypernetwork, + new_hypernetwork_layer_structure, + new_hypernetwork_activation_func, + new_hypernetwork_initialization_option, + new_hypernetwork_add_layer_norm, + new_hypernetwork_use_dropout, + new_hypernetwork_dropout_structure + ], + outputs=[ + train_hypernetwork_name, + ti_output, + ti_outcome, + ] + ) + + run_preprocess.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + process_src, + process_dst, + process_width, + process_height, + preprocess_txt_action, + process_flip, + process_split, + process_caption, + process_caption_deepbooru, + process_split_threshold, + process_overlap_ratio, + process_focal_crop, + process_focal_crop_face_weight, + process_focal_crop_entropy_weight, + process_focal_crop_edges_weight, + process_focal_crop_debug, + ], + outputs=[ + ti_output, + ti_outcome, + ], + ) + + train_embedding.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_embedding_name, + embedding_learn_rate, + batch_size, + gradient_step, + dataset_directory, + log_directory, + training_width, + training_height, + varsize, + steps, + clip_grad_mode, + clip_grad_value, + shuffle_tags, + tag_drop_out, + latent_sampling_method, + create_image_every, + save_embedding_every, + template_file, + save_image_with_stored_embedding, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + train_hypernetwork.click( + fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, 
extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_hypernetwork_name, + hypernetwork_learn_rate, + batch_size, + gradient_step, + dataset_directory, + log_directory, + training_width, + training_height, + varsize, + steps, + clip_grad_mode, + clip_grad_value, + shuffle_tags, + tag_drop_out, + latent_sampling_method, + create_image_every, + save_embedding_every, + template_file, + preview_from_txt2img, + *txt2img_preview_params, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + interrupt_training.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + interrupt_preprocessing.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + + def create_setting_component(key, is_quicksettings=False): + def fun(): + return opts.data[key] if key in opts.data else opts.data_labels[key].default + + info = opts.data_labels[key] + t = type(info.default) + + args = info.component_args() if callable(info.component_args) else info.component_args + + if info.component is not None: + comp = info.component + elif t == str: + comp = gr.Textbox + elif t == int: + comp = gr.Number + elif t == bool: + comp = gr.Checkbox + else: + raise Exception(f'bad options item type: {str(t)} for key {key}') + + elem_id = "setting_"+key + + if info.refresh is not None: + if is_quicksettings: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + with FormRow(): + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + else: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + + return res + + components = [] + component_dict = {} + + script_callbacks.ui_settings_callback() + opts.reorder() + + def run_settings(*args): + changed = [] + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" + + for key, value, comp in zip(opts.data_labels.keys(), args, components): + if comp == dummy_component: + continue + + if opts.set(key, value): + changed.append(key) + + try: + opts.save(shared.config_filename) + except RuntimeError: + return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' 
+ + def run_settings_single(value, key): + if not opts.same_type(value, opts.data_labels[key].default): + return gr.update(visible=True), opts.dumpjson() + + if not opts.set(key, value): + return gr.update(value=getattr(opts, key)), opts.dumpjson() + + opts.save(shared.config_filename) + + return gr.update(value=value), opts.dumpjson() + + with gr.Blocks(analytics_enabled=False) as settings_interface: + with gr.Row(): + with gr.Column(scale=6): + settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") + with gr.Column(): + restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") + + result = gr.HTML(elem_id="settings_result") + + quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] + quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} + + quicksettings_list = [] + + previous_section = None + current_tab = None + with gr.Tabs(elem_id="settings"): + for i, (k, item) in enumerate(opts.data_labels.items()): + section_must_be_skipped = item.section[0] is None + + if previous_section != item.section and not section_must_be_skipped: + elem_id, text = item.section + + if current_tab is not None: + current_tab.__exit__() + + current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) + current_tab.__enter__() + + previous_section = item.section + + if k in quicksettings_names and not shared.cmd_opts.freeze_settings: + quicksettings_list.append((i, k, item)) + components.append(dummy_component) + elif section_must_be_skipped: + components.append(dummy_component) + else: + component = create_setting_component(k) + component_dict[k] = component + components.append(component) + + if current_tab is not None: + current_tab.__exit__() + + with gr.TabItem("Actions"): + request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") + + if os.path.exists("html/licenses.html"): + with open("html/licenses.html", encoding="utf8") as file: + with gr.TabItem("Licenses"): + gr.HTML(file.read(), elem_id="licenses") + + gr.Button(value="Show all pages", elem_id="settings_show_all_pages") + + request_notifications.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='function(){}' + ) + + download_localization.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='download_localization' + ) + + def reload_scripts(): + modules.scripts.reload_script_body_only() + reload_javascript() # need to refresh the html page + + reload_script_bodies.click( + fn=reload_scripts, + inputs=[], + outputs=[] + ) + + def request_restart(): + shared.state.interrupt() + shared.state.need_restart = True + + restart_gradio.click( + fn=request_restart, + _js='restart_reload', + inputs=[], + outputs=[], + ) + + interfaces = [ + (txt2img_interface, "txt2img", "txt2img"), + (img2img_interface, "img2img", "img2img"), + (extras_interface, "Extras", "extras"), + (pnginfo_interface, "PNG Info", "pnginfo"), + (modelmerger_interface, "Checkpoint Merger", "modelmerger"), + (train_interface, "Train", "ti"), + ] + + css = "" + + for cssfile in modules.scripts.list_files_with_name("style.css"): + if not os.path.isfile(cssfile): + continue + + with open(cssfile, "r", 
encoding="utf8") as file: + css += file.read() + "\n" + + if os.path.exists(os.path.join(script_path, "user.css")): + with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: + css += file.read() + "\n" + + if not cmd_opts.no_progressbar_hiding: + css += css_hide_progressbar + + interfaces += script_callbacks.ui_tabs_callback() + interfaces += [(settings_interface, "Settings", "settings")] + + extensions_interface = ui_extensions.create_ui() + interfaces += [(extensions_interface, "Extensions", "extensions")] + + with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: + with gr.Row(elem_id="quicksettings"): + for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): + component = create_setting_component(k, is_quicksettings=True) + component_dict[k] = component + + parameters_copypaste.integrate_settings_paste_fields(component_dict) + parameters_copypaste.run_bind() + + with gr.Tabs(elem_id="tabs") as tabs: + for interface, label, ifid in interfaces: + with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): + interface.render() + + if os.path.exists(os.path.join(script_path, "notification.mp3")): + audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) + + if os.path.exists("html/footer.html"): + with open("html/footer.html", encoding="utf8") as file: + footer = file.read() + footer = footer.format(versions=versions_html()) + gr.HTML(footer, elem_id="footer") + + text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) + settings_submit.click( + fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), + inputs=components, + outputs=[text_settings, result], + ) + + for i, k, item in quicksettings_list: + component = component_dict[k] + + component.change( + fn=lambda value, k=k: run_settings_single(value, key=k), + inputs=[component], + outputs=[component, text_settings], + ) + + component_keys = [k for k in opts.data_labels.keys() if k in component_dict] + + def get_settings_values(): + return [getattr(opts, key) for key in component_keys] + + demo.load( + fn=get_settings_values, + inputs=[], + outputs=[component_dict[k] for k in component_keys], + ) + + def modelmerger(*args): + try: + results = modules.extras.run_modelmerger(*args) + except Exception as e: + print("Error loading/saving model file:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + modules.sd_models.list_models() # to remove the potentially missing models from the list + return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)] + return results + + modelmerger_merge.click( + fn=modelmerger, + inputs=[ + primary_model_name, + secondary_model_name, + tertiary_model_name, + interp_method, + interp_amount, + save_as_half, + custom_name, + checkpoint_format, + ], + outputs=[ + submit_result, + primary_model_name, + secondary_model_name, + tertiary_model_name, + component_dict['sd_model_checkpoint'], + ] + ) + + ui_config_file = cmd_opts.ui_config_file + ui_settings = {} + settings_count = len(ui_settings) + error_loading = False + + try: + if os.path.exists(ui_config_file): + with open(ui_config_file, "r", encoding="utf8") as file: + ui_settings = json.load(file) + except Exception: + error_loading = True + print("Error loading settings:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + def 
loadsave(path, x): + def apply_field(obj, field, condition=None, init_field=None): + key = path + "/" + field + + if getattr(obj, 'custom_script_source', None) is not None: + key = 'customscript/' + obj.custom_script_source + '/' + key + + if getattr(obj, 'do_not_save_to_config', False): + return + + saved_value = ui_settings.get(key, None) + if saved_value is None: + ui_settings[key] = getattr(obj, field) + elif condition and not condition(saved_value): + print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') + else: + setattr(obj, field, saved_value) + if init_field is not None: + init_field(saved_value) + + if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible: + apply_field(x, 'visible') + + if type(x) == gr.Slider: + apply_field(x, 'value') + apply_field(x, 'minimum') + apply_field(x, 'maximum') + apply_field(x, 'step') + + if type(x) == gr.Radio: + apply_field(x, 'value', lambda val: val in x.choices) + + if type(x) == gr.Checkbox: + apply_field(x, 'value') + + if type(x) == gr.Textbox: + apply_field(x, 'value') + + if type(x) == gr.Number: + apply_field(x, 'value') + + if type(x) == gr.Dropdown: + apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) + + visit(txt2img_interface, loadsave, "txt2img") + visit(img2img_interface, loadsave, "img2img") + visit(extras_interface, loadsave, "extras") + visit(modelmerger_interface, loadsave, "modelmerger") + visit(train_interface, loadsave, "train") + + if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): + with open(ui_config_file, "w", encoding="utf8") as file: + json.dump(ui_settings, file, indent=4) + + return demo + + +def reload_javascript(): + with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: + javascript = f'' + + scripts_list = modules.scripts.list_scripts("javascript", ".js") + + for basedir, filename, path in scripts_list: + with open(path, "r", encoding="utf8") as jsfile: + javascript += f"\n" + + if cmd_opts.theme is not None: + javascript += f"\n\n" + + javascript += f"\n" + + def template_response(*args, **kwargs): + res = shared.GradioTemplateResponseOriginal(*args, **kwargs) + res.body = res.body.replace( + b'', f'{javascript}'.encode("utf8")) + res.init_headers() + return res + + gradio.routes.templates.TemplateResponse = template_response + + +if not hasattr(shared, 'GradioTemplateResponseOriginal'): + shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse + + +def versions_html(): + import torch + import launch + + python_version = ".".join([str(x) for x in sys.version_info[0:3]]) + commit = launch.commit_hash() + short_commit = commit[0:8] + + if shared.xformers_available: + import xformers + xformers_version = xformers.__version__ + else: + xformers_version = "N/A" + + return f""" +python: {python_version} + •  +torch: {torch.__version__} + •  +xformers: {xformers_version} + •  +gradio: {gr.__version__} + •  +commit: {short_commit} +""" -- cgit v1.2.3 From ef75c980536471c0729a2319440e3083cd57a4f0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 10 Jan 2023 12:29:45 +0300 Subject: Split history ui.py to ui_progress.py --- modules/ui.py | 94 +-- modules/ui_progress.py | 1839 +----------------------------------------------- 2 files changed, 9 insertions(+), 1924 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 9b9081b5..3c458ce8 100644 --- 
a/modules/ui.py +++ b/modules/ui.py @@ -162,79 +162,6 @@ def save_files(js_data, images, do_make_zip, index): return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") -def calc_time_left(progress, threshold, label, force_display, show_eta): - if progress == 0: - return "" - else: - time_since_start = time.time() - shared.state.time_start - eta = (time_since_start/progress) - eta_relative = eta-time_since_start - if (eta_relative > threshold and show_eta) or force_display: - if eta_relative > 3600: - return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) - elif eta_relative > 60: - return label + time.strftime('%M:%S', time.gmtime(eta_relative)) - else: - return label + time.strftime('%Ss', time.gmtime(eta_relative)) - else: - return "" - - -def check_progress_call(id_part): - if shared.state.job_count == 0: - return "", gr_show(False), gr_show(False), gr_show(False) - - progress = 0 - - if shared.state.job_count > 0: - progress += shared.state.job_no / shared.state.job_count - if shared.state.sampling_steps > 0: - progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps - - # Show progress percentage and time left at the same moment, and base it also on steps done - show_eta = progress >= 0.01 or shared.state.sampling_step >= 10 - - time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta) - if time_left != "": - shared.state.time_left_force_display = True - - progress = min(progress, 1) - - progressbar = "" - if opts.show_progressbar: - progressbar = f"""
<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{"&nbsp;" * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}</div></div>
""" - - image = gr_show(False) - preview_visibility = gr_show(False) - - if opts.show_progress_every_n_steps != 0: - shared.state.set_current_image() - image = shared.state.current_image - - if image is None: - image = gr.update(value=None) - else: - preview_visibility = gr_show(True) - - if shared.state.textinfo is not None: - textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) - else: - textinfo_result = gr_show(False) - - return f"
<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>
", preview_visibility, image, textinfo_result - - -def check_progress_call_initial(id_part): - shared.state.job_count = -1 - shared.state.current_latent = None - shared.state.current_image = None - shared.state.textinfo = None - shared.state.time_start = time.time() - shared.state.time_left_force_display = False - - return check_progress_call(id_part) - - def visit(x, func, path=""): if hasattr(x, 'children'): for c in x.children: @@ -456,25 +383,10 @@ def create_toprow(is_img2img): return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button -def setup_progressbar(progressbar, preview, id_part, textinfo=None): - if textinfo is None: - textinfo = gr.HTML(visible=False) +def setup_progressbar(*args, **kwargs): + import modules.ui_progress - check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) - check_progress.click( - fn=lambda: check_progress_call(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) - check_progress_initial.click( - fn=lambda: check_progress_call_initial(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) + modules.ui_progress.setup_progressbar(*args, **kwargs) def apply_setting(key, value): diff --git a/modules/ui_progress.py b/modules/ui_progress.py index 9b9081b5..592fda55 100644 --- a/modules/ui_progress.py +++ b/modules/ui_progress.py @@ -1,165 +1,10 @@ -import html -import json -import math -import mimetypes -import os -import platform -import random -import subprocess as sp -import sys -import tempfile import time -import traceback -from functools import partial, reduce import gradio as gr -import gradio.routes -import gradio.utils -import numpy as np -from PIL import Image, PngImagePlugin -from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru -from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML -from modules.paths import script_path +from modules.shared import opts -from modules.shared import opts, cmd_opts, restricted_opts - -import modules.codeformer_model -import modules.generation_parameters_copypaste as parameters_copypaste -import modules.gfpgan_model -import modules.hypernetworks.ui -import modules.scripts import modules.shared as shared -import modules.styles -import modules.textual_inversion.ui -from modules import prompt_parser -from modules.images import save_image -from modules.sd_hijack import model_hijack -from modules.sd_samplers import samplers, samplers_for_img2img -from modules.textual_inversion import textual_inversion -import modules.hypernetworks.ui -from modules.generation_parameters_copypaste import image_from_url_text - -# this is a fix for Windows users. 
diff --git a/modules/ui_progress.py b/modules/ui_progress.py
index 9b9081b5..592fda55 100644
--- a/modules/ui_progress.py
+++ b/modules/ui_progress.py
@@ -1,165 +1,10 @@
-import html
-import json
-import math
-import mimetypes
-import os
-import platform
-import random
-import subprocess as sp
-import sys
-import tempfile
 import time
-import traceback
-from functools import partial, reduce
 
 import gradio as gr
-import gradio.routes
-import gradio.utils
-import numpy as np
-from PIL import Image, PngImagePlugin
-from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
 
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru
-from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
-from modules.paths import script_path
+from modules.shared import opts
 
-from modules.shared import opts, cmd_opts, restricted_opts
-
-import modules.codeformer_model
-import modules.generation_parameters_copypaste as parameters_copypaste
-import modules.gfpgan_model
-import modules.hypernetworks.ui
-import modules.scripts
 import modules.shared as shared
-import modules.styles
-import modules.textual_inversion.ui
-from modules import prompt_parser
-from modules.images import save_image
-from modules.sd_hijack import model_hijack
-from modules.sd_samplers import samplers, samplers_for_img2img
-from modules.textual_inversion import textual_inversion
-import modules.hypernetworks.ui
-from modules.generation_parameters_copypaste import image_from_url_text
-
-# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
-mimetypes.init()
-mimetypes.add_type('application/javascript', '.js')
-
-if not cmd_opts.share and not cmd_opts.listen:
-    # fix gradio phoning home
-    gradio.utils.version_check = lambda: None
-    gradio.utils.get_local_ip_address = lambda: '127.0.0.1'
-
-if cmd_opts.ngrok is not None:
-    import modules.ngrok as ngrok
-    print('ngrok authtoken detected, trying to connect...')
-    ngrok.connect(
-        cmd_opts.ngrok,
-        cmd_opts.port if cmd_opts.port is not None else 7860,
-        cmd_opts.ngrok_region
-    )
-
-
-def gr_show(visible=True):
-    return {"visible": visible, "__type__": "update"}
-
-
-sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
-sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
-
-css_hide_progressbar = """
-.wrap .m-12 svg { display:none!important; }
-.wrap .m-12::before { content:"Loading..." }
-.wrap .z-20 svg { display:none!important; }
-.wrap .z-20::before { content:"Loading..." }
-.progress-bar { display:none!important; }
-.meta-text { display:none!important; }
-.meta-text-center { display:none!important; }
-"""
-
-# Using constants for these since the variation selector isn't visible.
-# Important that they exactly match script.js for tooltip to work.
-random_symbol = '\U0001f3b2\ufe0f'  # 🎲️
-reuse_symbol = '\u267b\ufe0f'  # ♻️
-paste_symbol = '\u2199\ufe0f'  # ↙
-folder_symbol = '\U0001f4c2'  # 📂
-refresh_symbol = '\U0001f504'  # 🔄
-save_style_symbol = '\U0001f4be'  # 💾
-apply_style_symbol = '\U0001f4cb'  # 📋
-clear_prompt_symbol = '\U0001F5D1'  # 🗑️
-
-
-def plaintext_to_html(text):
-    text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
-    return text
-
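The plaintext_to_html helper deleted above, restated as a runnable standalone (behavior as reconstructed here): it HTML-escapes user text and joins lines with <br> inside one paragraph.

import html

def plaintext_to_html(text):
    # escape each line, then rejoin with <br> so the browser keeps the line breaks
    return "<p>" + "<br>\n".join(f"{html.escape(x)}" for x in text.split("\n")) + "</p>"

s = plaintext_to_html("a < b\nc & d")  # '<p>a &lt; b<br>\nc &amp; d</p>'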
" - return text - -def send_gradio_gallery_to_image(x): - if len(x) == 0: - return None - return image_from_url_text(x[0]) - -def save_files(js_data, images, do_make_zip, index): - import csv - filenames = [] - fullfns = [] - - #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it - class MyObject: - def __init__(self, d=None): - if d is not None: - for key, value in d.items(): - setattr(self, key, value) - - data = json.loads(js_data) - - p = MyObject(data) - path = opts.outdir_save - save_to_dirs = opts.use_save_to_dirs_for_ui - extension: str = opts.samples_format - start_index = 0 - - if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only - - images = [images[index]] - start_index = index - - os.makedirs(opts.outdir_save, exist_ok=True) - - with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file: - at_start = file.tell() == 0 - writer = csv.writer(file) - if at_start: - writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) - - for image_index, filedata in enumerate(images, start_index): - image = image_from_url_text(filedata) - - is_grid = image_index < p.index_of_first_image - i = 0 if is_grid else (image_index - p.index_of_first_image) - - fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) - - filename = os.path.relpath(fullfn, path) - filenames.append(filename) - fullfns.append(fullfn) - if txt_fullfn: - filenames.append(os.path.basename(txt_fullfn)) - fullfns.append(txt_fullfn) - - writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]]) - - # Make Zip - if do_make_zip: - zip_filepath = os.path.join(path, "images.zip") - - from zipfile import ZipFile - with ZipFile(zip_filepath, "w") as zip_file: - for i in range(len(fullfns)): - with open(fullfns[i], mode="rb") as f: - zip_file.writestr(filenames[i], f.read()) - fullfns.insert(0, zip_filepath) - - return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") def calc_time_left(progress, threshold, label, force_display, show_eta): @@ -182,7 +27,7 @@ def calc_time_left(progress, threshold, label, force_display, show_eta): def check_progress_call(id_part): if shared.state.job_count == 0: - return "", gr_show(False), gr_show(False), gr_show(False) + return "", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) progress = 0 @@ -204,8 +49,8 @@ def check_progress_call(id_part): if opts.show_progressbar: progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}
""" - image = gr_show(False) - preview_visibility = gr_show(False) + image = gr.update(visible=False) + preview_visibility = gr.update(visible=False) if opts.show_progress_every_n_steps != 0: shared.state.set_current_image() @@ -214,12 +59,12 @@ def check_progress_call(id_part): if image is None: image = gr.update(value=None) else: - preview_visibility = gr_show(True) + preview_visibility = gr.update(visible=True) if shared.state.textinfo is not None: textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) else: - textinfo_result = gr_show(False) + textinfo_result = gr.update(visible=False) return f"

{progressbar}

", preview_visibility, image, textinfo_result @@ -235,227 +80,6 @@ def check_progress_call_initial(id_part): return check_progress_call(id_part) -def visit(x, func, path=""): - if hasattr(x, 'children'): - for c in x.children: - visit(c, func, path) - elif x.label is not None: - func(path + "/" + str(x.label), x) - - -def add_style(name: str, prompt: str, negative_prompt: str): - if name is None: - return [gr_show() for x in range(4)] - - style = modules.styles.PromptStyle(name, prompt, negative_prompt) - shared.prompt_styles.styles[style.name] = style - # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we - # reserialize all styles every time we save them - shared.prompt_styles.save_styles(shared.styles_filename) - - return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] - - -def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): - from modules import processing, devices - - if not enable: - return "" - - p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) - - with devices.autocast(): - p.init([""], [0], [0]) - - return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" - - -def apply_styles(prompt, prompt_neg, style1_name, style2_name): - prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) - prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) - - return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] - - -def interrogate(image): - prompt = shared.interrogator.interrogate(image.convert("RGB")) - - return gr_show(True) if prompt is None else prompt - - -def interrogate_deepbooru(image): - prompt = deepbooru.model.tag(image) - return gr_show(True) if prompt is None else prompt - - -def create_seed_inputs(target_interface): - with FormRow(elem_id=target_interface + '_seed_row'): - seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed') - seed.style(container=False) - random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed') - reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed') - - with gr.Group(elem_id=target_interface + '_subseed_show_box'): - seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False) - - # Components to show/hide based on the 'Extra' checkbox - seed_extras = [] - - with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1: - seed_extras.append(seed_extra_row_1) - subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed') - subseed.style(container=False) - random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed') - reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed') - subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength') - - with FormRow(visible=False) as seed_extra_row_2: - seed_extras.append(seed_extra_row_2) - seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", 
-def create_seed_inputs(target_interface):
-    with FormRow(elem_id=target_interface + '_seed_row'):
-        seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=target_interface + '_seed')
-        seed.style(container=False)
-        random_seed = gr.Button(random_symbol, elem_id=target_interface + '_random_seed')
-        reuse_seed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_seed')
-
-        with gr.Group(elem_id=target_interface + '_subseed_show_box'):
-            seed_checkbox = gr.Checkbox(label='Extra', elem_id=target_interface + '_subseed_show', value=False)
-
-    # Components to show/hide based on the 'Extra' checkbox
-    seed_extras = []
-
-    with FormRow(visible=False, elem_id=target_interface + '_subseed_row') as seed_extra_row_1:
-        seed_extras.append(seed_extra_row_1)
-        subseed = gr.Number(label='Variation seed', value=-1, elem_id=target_interface + '_subseed')
-        subseed.style(container=False)
-        random_subseed = gr.Button(random_symbol, elem_id=target_interface + '_random_subseed')
-        reuse_subseed = gr.Button(reuse_symbol, elem_id=target_interface + '_reuse_subseed')
-        subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=target_interface + '_subseed_strength')
-
-    with FormRow(visible=False) as seed_extra_row_2:
-        seed_extras.append(seed_extra_row_2)
-        seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=target_interface + '_seed_resize_from_w')
-        seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=target_interface + '_seed_resize_from_h')
-
-    random_seed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[seed])
-    random_subseed.click(fn=lambda: -1, show_progress=False, inputs=[], outputs=[subseed])
-
-    def change_visibility(show):
-        return {comp: gr_show(show) for comp in seed_extras}
-
-    seed_checkbox.change(change_visibility, show_progress=False, inputs=[seed_checkbox], outputs=seed_extras)
-
-    return seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox
-
-
-
-def connect_clear_prompt(button):
-    """Given clear button, prompt, and token_counter objects, setup clear prompt button click event"""
-    button.click(
-        _js="clear_prompt",
-        fn=None,
-        inputs=[],
-        outputs=[],
-    )
-
-
-def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed):
-    """ Connects a 'reuse (sub)seed' button's click event so that it copies last used
-    (sub)seed value from generation info to the seed field. If copying subseed and subseed strength
-    was 0, i.e. no variation seed was used, it copies the normal seed value instead."""
-    def copy_seed(gen_info_string: str, index):
-        res = -1
-
-        try:
-            gen_info = json.loads(gen_info_string)
-            index -= gen_info.get('index_of_first_image', 0)
-
-            if is_subseed and gen_info.get('subseed_strength', 0) > 0:
-                all_subseeds = gen_info.get('all_subseeds', [-1])
-                res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0]
-            else:
-                all_seeds = gen_info.get('all_seeds', [-1])
-                res = all_seeds[index if 0 <= index < len(all_seeds) else 0]
-
-        except json.decoder.JSONDecodeError as e:
-            if gen_info_string != '':
-                print("Error parsing JSON generation info:", file=sys.stderr)
-                print(gen_info_string, file=sys.stderr)
-
-        return [res, gr_show(False)]
-
-    reuse_seed.click(
-        fn=copy_seed,
-        _js="(x, y) => [x, selected_gallery_index()]",
-        show_progress=False,
-        inputs=[generation_info, dummy_component],
-        outputs=[seed, dummy_component]
-    )
-
-
-def update_token_counter(text, steps):
-    try:
-        _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text])
-        prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps)
-
-    except Exception:
-        # a parsing error can happen here during typing, and we don't want to bother the user with
-        # messages related to it in console
-        prompt_schedules = [[[steps, text]]]
-
-    flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
-    prompts = [prompt_text for step, prompt_text in flat_prompts]
-    token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
-    style_class = ' class="red"' if (token_count > max_length) else ""
-    return f"<span {style_class}>{token_count}/{max_length}</span>"
-
-
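update_token_counter above flattens per-subprompt schedules, lists of [step, prompt_text] pairs, before counting tokens and reporting the longest prompt. The flattening step in isolation, with illustrative data:

from functools import reduce

prompt_schedules = [[[20, "a cat"]], [[10, "a dog"], [20, "a dog, detailed"]]]
flat_prompts = reduce(lambda list1, list2: list1 + list2, prompt_schedules)
prompts = [prompt_text for step, prompt_text in flat_prompts]
print(prompts)  # ['a cat', 'a dog', 'a dog, detailed']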
placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" - ) - - with gr.Column(scale=1, elem_id="roll_col"): - paste = gr.Button(value=paste_symbol, elem_id="paste") - save_style = gr.Button(value=save_style_symbol, elem_id="style_create") - prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") - token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") - token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - - clear_prompt_button.click( - fn=lambda *x: x, - _js="confirm_clear_prompt", - inputs=[prompt, negative_prompt], - outputs=[prompt, negative_prompt], - ) - - button_interrogate = None - button_deepbooru = None - if is_img2img: - with gr.Column(scale=1, elem_id="interrogate_col"): - button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") - - with gr.Column(scale=1): - with gr.Row(): - skip = gr.Button('Skip', elem_id=f"{id_part}_skip") - interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") - submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') - - skip.click( - fn=lambda: shared.state.skip(), - inputs=[], - outputs=[], - ) - - interrupt.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - with gr.Row(): - with gr.Column(scale=1, elem_id="style_pos_col"): - prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - - with gr.Column(scale=1, elem_id="style_neg_col"): - prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - - return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button - - def setup_progressbar(progressbar, preview, id_part, textinfo=None): if textinfo is None: textinfo = gr.HTML(visible=False) @@ -475,1454 +99,3 @@ def setup_progressbar(progressbar, preview, id_part, textinfo=None): inputs=[], outputs=[progressbar, preview, preview, textinfo], ) - - -def apply_setting(key, value): - if value is None: - return gr.update() - - if shared.cmd_opts.freeze_settings: - return gr.update() - - # dont allow model to be swapped when model hash exists in prompt - if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: - return gr.update() - - if key == "sd_model_checkpoint": - ckpt_info = sd_models.get_closet_checkpoint_match(value) - - if ckpt_info is not None: - value = ckpt_info.title - else: - return gr.update() - - comp_args = opts.data_labels[key].component_args - if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: - return - - valtype = type(opts.data_labels[key].default) - oldval = opts.data.get(key, None) - opts.data[key] = valtype(value) if valtype != type(None) else value - if oldval != value and opts.data_labels[key].onchange is not None: - opts.data_labels[key].onchange() - - opts.save(shared.config_filename) - return value - - -def update_generation_info(args): - generation_info, html_info, img_index = args - try: - generation_info = json.loads(generation_info) - if img_index < 0 or img_index >= 
len(generation_info["infotexts"]): - return html_info - return plaintext_to_html(generation_info["infotexts"][img_index]) - except Exception: - pass - # if the json parse or anything else fails, just return the old html_info - return html_info - - -def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): - def refresh(): - refresh_method() - args = refreshed_args() if callable(refreshed_args) else refreshed_args - - for k, v in args.items(): - setattr(refresh_component, k, v) - - return gr.update(**(args or {})) - - refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) - refresh_button.click( - fn=refresh, - inputs=[], - outputs=[refresh_component] - ) - return refresh_button - - -def create_output_panel(tabname, outdir): - def open_folder(f): - if not os.path.exists(f): - print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') - return - elif not os.path.isdir(f): - print(f""" -WARNING -An open_folder request was made with an argument that is not a folder. -This could be an error or a malicious attempt to run code on your computer. -Requested path was: {f} -""", file=sys.stderr) - return - - if not shared.cmd_opts.hide_ui_dir_config: - path = os.path.normpath(f) - if platform.system() == "Windows": - os.startfile(path) - elif platform.system() == "Darwin": - sp.Popen(["open", path]) - elif "microsoft-standard-WSL2" in platform.uname().release: - sp.Popen(["wsl-open", path]) - else: - sp.Popen(["xdg-open", path]) - - with gr.Column(variant='panel'): - with gr.Group(): - result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) - - generation_info = None - with gr.Column(): - with gr.Row(elem_id=f"image_buttons_{tabname}"): - open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') - - if tabname != "extras": - save = gr.Button('Save', elem_id=f'save_{tabname}') - save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}') - - buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"]) - - open_folder_button.click( - fn=lambda: open_folder(opts.outdir_samples or outdir), - inputs=[], - outputs=[], - ) - - if tabname != "extras": - with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') - - with gr.Group(): - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') - if tabname == 'txt2img' or tabname == 'img2img': - generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") - generation_info_button.click( - fn=update_generation_info, - _js="(x, y) => [x, y, selected_gallery_index()]", - inputs=[generation_info, html_info], - outputs=[html_info], - preprocess=False - ) - - save.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - save_zip.click( - fn=wrap_gradio_call(save_files), - _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]", - inputs=[ - generation_info, - result_gallery, - html_info, - html_info, - ], - outputs=[ - download_files, - html_log, - ] - ) - - else: - html_info_x = 
gr.HTML(elem_id=f'html_info_x_{tabname}') - html_info = gr.HTML(elem_id=f'html_info_{tabname}') - html_log = gr.HTML(elem_id=f'html_log_{tabname}') - - parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) - return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log - - -def create_sampler_and_steps_selection(choices, tabname): - if opts.samplers_in_dropdown: - with FormRow(elem_id=f"sampler_selection_{tabname}"): - sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - else: - with FormGroup(elem_id=f"sampler_selection_{tabname}"): - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") - - return steps, sampler_index - - -def ordered_ui_categories(): - user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))} - - for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)): - yield category - - -def create_ui(): - import modules.img2img - import modules.txt2img - - reload_javascript() - - parameters_copypaste.reset() - - modules.scripts.scripts_current = modules.scripts.scripts_txt2img - modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) - - with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) - - dummy_component = gr.Label(visible=False) - txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Row(elem_id='txt2img_progress_row'): - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="txt2img_progressbar") - txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) - setup_progressbar(progressbar, txt2img_preview, 'txt2img') - - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel', elem_id="txt2img_settings"): - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") - - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="txt2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="txt2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") - - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - - elif category == "cfg": - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") - - elif category == "seed": - seed, reuse_seed, subseed, reuse_subseed, 
subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') - - elif category == "checkboxes": - with FormRow(elem_id="txt2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") - enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") - hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False) - - elif category == "hires_fix": - with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: - with FormRow(elem_id="txt2img_hires_fix_row1"): - hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) - hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") - - with FormRow(elem_id="txt2img_hires_fix_row2"): - hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") - hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") - hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") - - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="txt2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") - - elif category == "scripts": - with FormGroup(elem_id="txt2img_script_container"): - custom_inputs = modules.scripts.scripts_txt2img.setup_ui() - - hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] - for input in hr_resolution_preview_inputs: - input.change( - fn=calc_resolution_hires, - inputs=hr_resolution_preview_inputs, - outputs=[hr_final_resolution], - show_progress=False, - ) - input.change( - None, - _js="onCalcResolutionHires", - inputs=hr_resolution_preview_inputs, - outputs=[], - show_progress=False, - ) - - txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) - parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - txt2img_args = dict( - fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), - _js="submit", - inputs=[ - txt2img_prompt, - txt2img_negative_prompt, - txt2img_prompt_style, - txt2img_prompt_style2, - steps, - sampler_index, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - enable_hr, - denoising_strength, - hr_scale, - hr_upscaler, - hr_second_pass_steps, - hr_resize_x, - hr_resize_y, - ] + custom_inputs, - - 
outputs=[ - txt2img_gallery, - generation_info, - html_info, - html_log, - ], - show_progress=False, - ) - - txt2img_prompt.submit(**txt2img_args) - submit.click(**txt2img_args) - - txt_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - txt_prompt_img - ], - outputs=[ - txt2img_prompt, - txt_prompt_img - ] - ) - - enable_hr.change( - fn=lambda x: gr_show(x), - inputs=[enable_hr], - outputs=[hr_options], - show_progress = False, - ) - - txt2img_paste_fields = [ - (txt2img_prompt, "Prompt"), - (txt2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - (enable_hr, lambda d: "Denoising strength" in d), - (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), - (hr_scale, "Hires upscale"), - (hr_upscaler, "Hires upscaler"), - (hr_second_pass_steps, "Hires steps"), - (hr_resize_x, "Hires resize-1"), - (hr_resize_y, "Hires resize-2"), - *modules.scripts.scripts_txt2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) - - txt2img_preview_params = [ - txt2img_prompt, - txt2img_negative_prompt, - steps, - sampler_index, - cfg_scale, - seed, - width, - height, - ] - - token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) - - modules.scripts.scripts_current = modules.scripts.scripts_img2img - modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) - - with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) - - with gr.Row(elem_id='img2img_progress_row'): - img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="img2img_progressbar") - img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) - setup_progressbar(progressbar, img2img_preview, 'img2img') - - with FormRow().style(equal_height=False): - with gr.Column(variant='panel', elem_id="img2img_settings"): - - with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: - with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) - - with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) - init_img_with_mask_orig = gr.State(None) - - use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch" - if use_color_sketch: - def update_orig(image, state): - if image is not None: - same_size = 
state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state - - init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) - - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") - - with FormRow(): - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") - mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") - - with FormRow(): - mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") - inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - - with FormRow(): - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") - - with FormRow(): - with gr.Column(): - inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") - - with gr.Column(scale=4): - inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - - with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): - hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
-                        gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>
") - img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") - img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") - - with FormRow(): - resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - - for category in ordered_ui_categories(): - if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") - - elif category == "dimensions": - with FormRow(): - with gr.Column(elem_id="img2img_column_size", scale=4): - width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="img2img_width") - height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") - - if opts.dimensions_and_batch_together: - with gr.Column(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - elif category == "cfg": - with FormGroup(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength") - - elif category == "seed": - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') - - elif category == "checkboxes": - with FormRow(elem_id="img2img_checkboxes"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") - tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") - - elif category == "batch": - if not opts.dimensions_and_batch_together: - with FormRow(elem_id="img2img_column_batch"): - batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") - batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - - elif category == "scripts": - with FormGroup(elem_id="img2img_script_container"): - custom_inputs = modules.scripts.scripts_img2img.setup_ui() - - img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) - parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - img2img_prompt_img.change( - fn=modules.images.image_data, - inputs=[ - img2img_prompt_img - ], - outputs=[ - img2img_prompt, - img2img_prompt_img - ] - ) - - mask_mode.change( - lambda mode, img: { - init_img_with_mask: gr_show(mode == 0), - init_img_inpaint: gr_show(mode == 1), - init_mask_inpaint: gr_show(mode == 1), - }, - inputs=[mask_mode, init_img_with_mask], - outputs=[ - init_img_with_mask, - init_img_inpaint, - init_mask_inpaint, - ], - ) - - img2img_args = dict( - fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), - _js="submit_img2img", - inputs=[ - 
dummy_component, - img2img_prompt, - img2img_negative_prompt, - img2img_prompt_style, - img2img_prompt_style2, - init_img, - init_img_with_mask, - init_img_with_mask_orig, - init_img_inpaint, - init_mask_inpaint, - mask_mode, - steps, - sampler_index, - mask_blur, - mask_alpha, - inpainting_fill, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - denoising_strength, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - resize_mode, - inpaint_full_res, - inpaint_full_res_padding, - inpainting_mask_invert, - img2img_batch_input_dir, - img2img_batch_output_dir, - ] + custom_inputs, - outputs=[ - img2img_gallery, - generation_info, - html_info, - html_log, - ], - show_progress=False, - ) - - img2img_prompt.submit(**img2img_args) - submit.click(**img2img_args) - - img2img_interrogate.click( - fn=interrogate, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - img2img_deepbooru.click( - fn=interrogate_deepbooru, - inputs=[init_img], - outputs=[img2img_prompt], - ) - - prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] - style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] - style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] - - for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): - button.click( - fn=add_style, - _js="ask_for_style_name", - # Have to pass empty dummy component here, because the JavaScript and Python function have to accept - # the same number of parameters, but we only know the style-name after the JavaScript prompt - inputs=[dummy_component, prompt, negative_prompt], - outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], - ) - - for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): - button.click( - fn=apply_styles, - _js=js_func, - inputs=[prompt, negative_prompt, style1, style2], - outputs=[prompt, negative_prompt, style1, style2], - ) - - token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) - - img2img_paste_fields = [ - (img2img_prompt, "Prompt"), - (img2img_negative_prompt, "Negative prompt"), - (steps, "Steps"), - (sampler_index, "Sampler"), - (restore_faces, "Face restoration"), - (cfg_scale, "CFG scale"), - (seed, "Seed"), - (width, "Size-1"), - (height, "Size-2"), - (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - (seed_resize_from_h, "Seed resize from-2"), - (denoising_strength, "Denoising strength"), - (mask_blur, "Mask blur"), - *modules.scripts.scripts_img2img.infotext_fields - ] - parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) - parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) - - modules.scripts.scripts_current = None - - with gr.Blocks(analytics_enabled=False) as extras_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - with gr.Tabs(elem_id="mode_extras"): - with gr.TabItem('Single Image', elem_id="extras_single_tab"): - extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") - - with gr.TabItem('Batch Process', 
elem_id="extras_batch_process_tab"): - image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch") - - with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"): - extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir") - extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir") - show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results") - - submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') - - with gr.Tabs(elem_id="extras_resize_mode"): - with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"): - upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize") - with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"): - with gr.Group(): - with gr.Row(): - upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w") - upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h") - upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop") - - with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - - with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") - extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility") - - with gr.Group(): - gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility") - - with gr.Group(): - codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility") - codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight") - - with gr.Group(): - upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix") - - result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) - - submit.click( - fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']), - _js="get_extras_tab_index", - inputs=[ - dummy_component, - dummy_component, - extras_image, - image_batch, - extras_batch_input_dir, - extras_batch_output_dir, - show_extras_results, - gfpgan_visibility, - codeformer_visibility, - codeformer_weight, - upscaling_resize, - upscaling_resize_w, - upscaling_resize_h, - upscaling_crop, - extras_upscaler_1, - extras_upscaler_2, - extras_upscaler_2_visibility, - upscale_before_face_fix, - ], - outputs=[ - result_images, - html_info_x, 
- html_info, - ] - ) - parameters_copypaste.add_paste_fields("extras", extras_image, None) - - extras_image.change( - fn=modules.extras.clear_cache, - inputs=[], outputs=[] - ) - - with gr.Blocks(analytics_enabled=False) as pnginfo_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") - - with gr.Column(variant='panel'): - html = gr.HTML() - generation_info = gr.Textbox(visible=False, elem_id="pnginfo_generation_info") - html2 = gr.HTML() - with gr.Row(): - buttons = parameters_copypaste.create_buttons(["txt2img", "img2img", "inpaint", "extras"]) - parameters_copypaste.bind_buttons(buttons, image, generation_info) - - image.change( - fn=wrap_gradio_call(modules.extras.run_pnginfo), - inputs=[image], - outputs=[html, generation_info, html2], - ) - - with gr.Blocks(analytics_enabled=False) as modelmerger_interface: - with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - gr.HTML(value="

<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your checkpoint directory.</p>
") - - with gr.Row(): - primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") - create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") - - secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") - create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B") - - tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") - create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") - - custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") - interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") - interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") - - with gr.Row(): - checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") - save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") - - modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') - - with gr.Column(variant='panel'): - submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) - - with gr.Blocks(analytics_enabled=False) as train_interface: - with gr.Row().style(equal_height=False): - gr.HTML(value="

<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>
") - - with gr.Row().style(equal_height=False): - with gr.Tabs(elem_id="train_tabs"): - - with gr.Tab(label="Create embedding"): - new_embedding_name = gr.Textbox(label="Name", elem_id="train_new_embedding_name") - initialization_text = gr.Textbox(label="Initialization text", value="*", elem_id="train_initialization_text") - nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1, elem_id="train_nvpt") - overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding", elem_id="train_overwrite_old_embedding") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_embedding = gr.Button(value="Create embedding", variant='primary', elem_id="train_create_embedding") - - with gr.Tab(label="Create hypernetwork"): - new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") - new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") - new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") - new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") - new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") - new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") - new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") - new_hypernetwork_dropout_structure = gr.Textbox("0, 0, 0", label="Enter hypernetwork Dropout structure (or empty). Recommended : 0~0.35 incrementing sequence: 0, 0.05, 0.15", placeholder="1st and last digit must be 0 and values should be between 0 and 1. 
ex:'0, 0.01, 0'") - overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork", elem_id="train_overwrite_old_hypernetwork") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary', elem_id="train_create_hypernetwork") - - with gr.Tab(label="Preprocess images"): - process_src = gr.Textbox(label='Source directory', elem_id="train_process_src") - process_dst = gr.Textbox(label='Destination directory', elem_id="train_process_dst") - process_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_process_width") - process_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_process_height") - preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"], elem_id="train_preprocess_txt_action") - - with gr.Row(): - process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") - process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") - process_focal_crop = gr.Checkbox(label='Auto focal point crop', elem_id="train_process_focal_crop") - process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") - - with gr.Row(visible=False) as process_split_extra_row: - process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_split_threshold") - process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="train_process_overlap_ratio") - - with gr.Row(visible=False) as process_focal_crop_row: - process_focal_crop_face_weight = gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_face_weight") - process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") - process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") - process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") - - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - with gr.Row(): - interrupt_preprocessing = gr.Button("Interrupt", elem_id="train_interrupt_preprocessing") - run_preprocess = gr.Button(value="Preprocess", variant='primary', elem_id="train_run_preprocess") - - process_split.change( - fn=lambda show: gr_show(show), - inputs=[process_split], - outputs=[process_split_extra_row], - ) - - process_focal_crop.change( - fn=lambda show: gr_show(show), - inputs=[process_focal_crop], - outputs=[process_focal_crop_row], - ) - - def get_textual_inversion_template_names(): - return sorted([x for x in textual_inversion.textual_inversion_templates]) - - with gr.Tab(label="Train"): - gr.HTML(value="

<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>
") - with FormRow(): - train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) - create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") - - train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) - create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") - - with FormRow(): - embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate") - hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate") - - with FormRow(): - clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"]) - clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False) - - with FormRow(): - batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size") - gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step") - - dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") - log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") - - with FormRow(): - template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) - create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") - - training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") - training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") - varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") - steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") - - with FormRow(): - create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_create_image_every") - save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0, elem_id="train_save_embedding_every") - - save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True, elem_id="train_save_image_with_stored_embedding") - preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) 
from txt2img tab when making previews', value=False, elem_id="train_preview_from_txt2img") - - shuffle_tags = gr.Checkbox(label="Shuffle tags by ',' when creating prompts.", value=False, elem_id="train_shuffle_tags") - tag_drop_out = gr.Slider(minimum=0, maximum=1, step=0.1, label="Drop out tags when creating prompts.", value=0, elem_id="train_tag_drop_out") - - latent_sampling_method = gr.Radio(label='Choose latent sampling method', value="once", choices=['once', 'deterministic', 'random'], elem_id="train_latent_sampling_method") - - with gr.Row(): - train_embedding = gr.Button(value="Train Embedding", variant='primary', elem_id="train_train_embedding") - interrupt_training = gr.Button(value="Interrupt", elem_id="train_interrupt_training") - train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary', elem_id="train_train_hypernetwork") - - params = script_callbacks.UiTrainTabParams(txt2img_preview_params) - - script_callbacks.ui_train_tabs_callback(params) - - with gr.Column(): - progressbar = gr.HTML(elem_id="ti_progressbar") - ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) - - ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) - ti_preview = gr.Image(elem_id='ti_preview', visible=False) - ti_progress = gr.HTML(elem_id="ti_progress", value="") - ti_outcome = gr.HTML(elem_id="ti_error", value="") - setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) - - create_embedding.click( - fn=modules.textual_inversion.ui.create_embedding, - inputs=[ - new_embedding_name, - initialization_text, - nvpt, - overwrite_old_embedding, - ], - outputs=[ - train_embedding_name, - ti_output, - ti_outcome, - ] - ) - - create_hypernetwork.click( - fn=modules.hypernetworks.ui.create_hypernetwork, - inputs=[ - new_hypernetwork_name, - new_hypernetwork_sizes, - overwrite_old_hypernetwork, - new_hypernetwork_layer_structure, - new_hypernetwork_activation_func, - new_hypernetwork_initialization_option, - new_hypernetwork_add_layer_norm, - new_hypernetwork_use_dropout, - new_hypernetwork_dropout_structure - ], - outputs=[ - train_hypernetwork_name, - ti_output, - ti_outcome, - ] - ) - - run_preprocess.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - process_src, - process_dst, - process_width, - process_height, - preprocess_txt_action, - process_flip, - process_split, - process_caption, - process_caption_deepbooru, - process_split_threshold, - process_overlap_ratio, - process_focal_crop, - process_focal_crop_face_weight, - process_focal_crop_entropy_weight, - process_focal_crop_edges_weight, - process_focal_crop_debug, - ], - outputs=[ - ti_output, - ti_outcome, - ], - ) - - train_embedding.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_embedding_name, - embedding_learn_rate, - batch_size, - gradient_step, - dataset_directory, - log_directory, - training_width, - training_height, - varsize, - steps, - clip_grad_mode, - clip_grad_value, - shuffle_tags, - tag_drop_out, - latent_sampling_method, - create_image_every, - save_embedding_every, - template_file, - save_image_with_stored_embedding, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - train_hypernetwork.click( - fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, 
extra_outputs=[gr.update()]), - _js="start_training_textual_inversion", - inputs=[ - train_hypernetwork_name, - hypernetwork_learn_rate, - batch_size, - gradient_step, - dataset_directory, - log_directory, - training_width, - training_height, - varsize, - steps, - clip_grad_mode, - clip_grad_value, - shuffle_tags, - tag_drop_out, - latent_sampling_method, - create_image_every, - save_embedding_every, - template_file, - preview_from_txt2img, - *txt2img_preview_params, - ], - outputs=[ - ti_output, - ti_outcome, - ] - ) - - interrupt_training.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - interrupt_preprocessing.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) - - def create_setting_component(key, is_quicksettings=False): - def fun(): - return opts.data[key] if key in opts.data else opts.data_labels[key].default - - info = opts.data_labels[key] - t = type(info.default) - - args = info.component_args() if callable(info.component_args) else info.component_args - - if info.component is not None: - comp = info.component - elif t == str: - comp = gr.Textbox - elif t == int: - comp = gr.Number - elif t == bool: - comp = gr.Checkbox - else: - raise Exception(f'bad options item type: {str(t)} for key {key}') - - elem_id = "setting_"+key - - if info.refresh is not None: - if is_quicksettings: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - with FormRow(): - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) - else: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - - return res - - components = [] - component_dict = {} - - script_callbacks.ui_settings_callback() - opts.reorder() - - def run_settings(*args): - changed = [] - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - if comp == dummy_component: - continue - - if opts.set(key, value): - changed.append(key) - - try: - opts.save(shared.config_filename) - except RuntimeError: - return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' - return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' 
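For context on the listing above: the settings tab is generated from the registered options rather than hand-written. `create_setting_component` picks a Gradio component by the type of each option's default value, and `run_settings` applies submitted values in two passes, validating every value before committing any of them. A condensed, standalone sketch of the type dispatch (hypothetical helper name, not the actual webui function) might look like this:

    import gradio as gr

    def component_for(default, label, elem_id, explicit=None):
        # An option may name an explicit component; otherwise dispatch on the
        # type of its default value, mirroring create_setting_component above.
        by_type = {str: gr.Textbox, int: gr.Number, bool: gr.Checkbox}
        comp = explicit or by_type.get(type(default))
        if comp is None:
            raise Exception(f'bad options item type: {type(default)} for key {elem_id}')
        return comp(label=label, value=default, elem_id=elem_id)

The two-pass apply in `run_settings` matters because a single bad value would otherwise leave the options half-updated; asserting all types first keeps the save atomic from the user's point of view.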
- - def run_settings_single(value, key): - if not opts.same_type(value, opts.data_labels[key].default): - return gr.update(visible=True), opts.dumpjson() - - if not opts.set(key, value): - return gr.update(value=getattr(opts, key)), opts.dumpjson() - - opts.save(shared.config_filename) - - return gr.update(value=value), opts.dumpjson() - - with gr.Blocks(analytics_enabled=False) as settings_interface: - with gr.Row(): - with gr.Column(scale=6): - settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") - with gr.Column(): - restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") - - result = gr.HTML(elem_id="settings_result") - - quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] - quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} - - quicksettings_list = [] - - previous_section = None - current_tab = None - with gr.Tabs(elem_id="settings"): - for i, (k, item) in enumerate(opts.data_labels.items()): - section_must_be_skipped = item.section[0] is None - - if previous_section != item.section and not section_must_be_skipped: - elem_id, text = item.section - - if current_tab is not None: - current_tab.__exit__() - - current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) - current_tab.__enter__() - - previous_section = item.section - - if k in quicksettings_names and not shared.cmd_opts.freeze_settings: - quicksettings_list.append((i, k, item)) - components.append(dummy_component) - elif section_must_be_skipped: - components.append(dummy_component) - else: - component = create_setting_component(k) - component_dict[k] = component - components.append(component) - - if current_tab is not None: - current_tab.__exit__() - - with gr.TabItem("Actions"): - request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") - download_localization = gr.Button(value='Download localization template', elem_id="download_localization") - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") - - if os.path.exists("html/licenses.html"): - with open("html/licenses.html", encoding="utf8") as file: - with gr.TabItem("Licenses"): - gr.HTML(file.read(), elem_id="licenses") - - gr.Button(value="Show all pages", elem_id="settings_show_all_pages") - - request_notifications.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='function(){}' - ) - - download_localization.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='download_localization' - ) - - def reload_scripts(): - modules.scripts.reload_script_body_only() - reload_javascript() # need to refresh the html page - - reload_script_bodies.click( - fn=reload_scripts, - inputs=[], - outputs=[] - ) - - def request_restart(): - shared.state.interrupt() - shared.state.need_restart = True - - restart_gradio.click( - fn=request_restart, - _js='restart_reload', - inputs=[], - outputs=[], - ) - - interfaces = [ - (txt2img_interface, "txt2img", "txt2img"), - (img2img_interface, "img2img", "img2img"), - (extras_interface, "Extras", "extras"), - (pnginfo_interface, "PNG Info", "pnginfo"), - (modelmerger_interface, "Checkpoint Merger", "modelmerger"), - (train_interface, "Train", "ti"), - ] - - css = "" - - for cssfile in modules.scripts.list_files_with_name("style.css"): - if not os.path.isfile(cssfile): - continue - - with open(cssfile, "r", 
encoding="utf8") as file: - css += file.read() + "\n" - - if os.path.exists(os.path.join(script_path, "user.css")): - with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: - css += file.read() + "\n" - - if not cmd_opts.no_progressbar_hiding: - css += css_hide_progressbar - - interfaces += script_callbacks.ui_tabs_callback() - interfaces += [(settings_interface, "Settings", "settings")] - - extensions_interface = ui_extensions.create_ui() - interfaces += [(extensions_interface, "Extensions", "extensions")] - - with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: - with gr.Row(elem_id="quicksettings"): - for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): - component = create_setting_component(k, is_quicksettings=True) - component_dict[k] = component - - parameters_copypaste.integrate_settings_paste_fields(component_dict) - parameters_copypaste.run_bind() - - with gr.Tabs(elem_id="tabs") as tabs: - for interface, label, ifid in interfaces: - with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): - interface.render() - - if os.path.exists(os.path.join(script_path, "notification.mp3")): - audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) - - if os.path.exists("html/footer.html"): - with open("html/footer.html", encoding="utf8") as file: - footer = file.read() - footer = footer.format(versions=versions_html()) - gr.HTML(footer, elem_id="footer") - - text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) - settings_submit.click( - fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), - inputs=components, - outputs=[text_settings, result], - ) - - for i, k, item in quicksettings_list: - component = component_dict[k] - - component.change( - fn=lambda value, k=k: run_settings_single(value, key=k), - inputs=[component], - outputs=[component, text_settings], - ) - - component_keys = [k for k in opts.data_labels.keys() if k in component_dict] - - def get_settings_values(): - return [getattr(opts, key) for key in component_keys] - - demo.load( - fn=get_settings_values, - inputs=[], - outputs=[component_dict[k] for k in component_keys], - ) - - def modelmerger(*args): - try: - results = modules.extras.run_modelmerger(*args) - except Exception as e: - print("Error loading/saving model file:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - modules.sd_models.list_models() # to remove the potentially missing models from the list - return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)] - return results - - modelmerger_merge.click( - fn=modelmerger, - inputs=[ - primary_model_name, - secondary_model_name, - tertiary_model_name, - interp_method, - interp_amount, - save_as_half, - custom_name, - checkpoint_format, - ], - outputs=[ - submit_result, - primary_model_name, - secondary_model_name, - tertiary_model_name, - component_dict['sd_model_checkpoint'], - ] - ) - - ui_config_file = cmd_opts.ui_config_file - ui_settings = {} - settings_count = len(ui_settings) - error_loading = False - - try: - if os.path.exists(ui_config_file): - with open(ui_config_file, "r", encoding="utf8") as file: - ui_settings = json.load(file) - except Exception: - error_loading = True - print("Error loading settings:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - - def 
loadsave(path, x): - def apply_field(obj, field, condition=None, init_field=None): - key = path + "/" + field - - if getattr(obj, 'custom_script_source', None) is not None: - key = 'customscript/' + obj.custom_script_source + '/' + key - - if getattr(obj, 'do_not_save_to_config', False): - return - - saved_value = ui_settings.get(key, None) - if saved_value is None: - ui_settings[key] = getattr(obj, field) - elif condition and not condition(saved_value): - print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') - else: - setattr(obj, field, saved_value) - if init_field is not None: - init_field(saved_value) - - if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown] and x.visible: - apply_field(x, 'visible') - - if type(x) == gr.Slider: - apply_field(x, 'value') - apply_field(x, 'minimum') - apply_field(x, 'maximum') - apply_field(x, 'step') - - if type(x) == gr.Radio: - apply_field(x, 'value', lambda val: val in x.choices) - - if type(x) == gr.Checkbox: - apply_field(x, 'value') - - if type(x) == gr.Textbox: - apply_field(x, 'value') - - if type(x) == gr.Number: - apply_field(x, 'value') - - if type(x) == gr.Dropdown: - apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) - - visit(txt2img_interface, loadsave, "txt2img") - visit(img2img_interface, loadsave, "img2img") - visit(extras_interface, loadsave, "extras") - visit(modelmerger_interface, loadsave, "modelmerger") - visit(train_interface, loadsave, "train") - - if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)): - with open(ui_config_file, "w", encoding="utf8") as file: - json.dump(ui_settings, file, indent=4) - - return demo - - -def reload_javascript(): - with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: - javascript = f'' - - scripts_list = modules.scripts.list_scripts("javascript", ".js") - - for basedir, filename, path in scripts_list: - with open(path, "r", encoding="utf8") as jsfile: - javascript += f"\n" - - if cmd_opts.theme is not None: - javascript += f"\n\n" - - javascript += f"\n" - - def template_response(*args, **kwargs): - res = shared.GradioTemplateResponseOriginal(*args, **kwargs) - res.body = res.body.replace( - b'', f'{javascript}'.encode("utf8")) - res.init_headers() - return res - - gradio.routes.templates.TemplateResponse = template_response - - -if not hasattr(shared, 'GradioTemplateResponseOriginal'): - shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse - - -def versions_html(): - import torch - import launch - - python_version = ".".join([str(x) for x in sys.version_info[0:3]]) - commit = launch.commit_hash() - short_commit = commit[0:8] - - if shared.xformers_available: - import xformers - xformers_version = xformers.__version__ - else: - xformers_version = "N/A" - - return f""" -python: {python_version} - •  -torch: {torch.__version__} - •  -xformers: {xformers_version} - •  -gradio: {gr.__version__} - •  -commit: {short_commit} -""" -- cgit v1.2.3 From 76a21b9626b7556638db188c157e3e8036803326 Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Tue, 10 Jan 2023 12:46:35 +0300 Subject: clear envvar, add assertion --- launch.py | 1 + test/basic_features/txt2img_test.py | 1 + 2 files changed, 2 insertions(+) diff --git a/launch.py b/launch.py index 49b91b1f..bcbb792c 100644 --- a/launch.py +++ b/launch.py @@ -282,6 +282,7 @@ def 
tests(test_dir): print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}") + os.environ['COMMANDLINE_ARGS'] = "" with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr: proc = subprocess.Popen([sys.executable, *sys.argv], stdout=stdout, stderr=stderr) diff --git a/test/basic_features/txt2img_test.py b/test/basic_features/txt2img_test.py index 5b27a7ec..5aa43a44 100644 --- a/test/basic_features/txt2img_test.py +++ b/test/basic_features/txt2img_test.py @@ -43,6 +43,7 @@ class TestTxt2ImgWorking(unittest.TestCase): def test_txt2img_with_complex_prompt_performed(self): self.simple_txt2img["prompt"] = "((emphasis)), (emphasis1:1.1), [to:1], [from::2], [from:to:0.3], [alt|alt1]" + self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200) def test_txt2img_not_square_image_performed(self): self.simple_txt2img["height"] = 128 -- cgit v1.2.3 From 0c3feb202c5714abd50d879c1db2cd9a71ce93e3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 10 Jan 2023 14:08:29 +0300 Subject: disable torch weight initialization and CLIP downloading/reading checkpoint to speedup creating sd model from config --- modules/sd_disable_initialization.py | 44 ++++++++++++++++++++++++++++++++++++ modules/sd_models.py | 5 ++-- 2 files changed, 47 insertions(+), 2 deletions(-) create mode 100644 modules/sd_disable_initialization.py diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py new file mode 100644 index 00000000..c9a3b5e4 --- /dev/null +++ b/modules/sd_disable_initialization.py @@ -0,0 +1,44 @@ +import ldm.modules.encoders.modules +import open_clip +import torch + + +class DisableInitialization: + """ + When an object of this class enters a `with` block, it starts preventing torch's layer initialization + functions from working, and changes CLIP and OpenCLIP to not download model weights. When it leaves, + reverts everything to how it was. 
+ + Use like this: + ``` + with DisableInitialization(): + do_things() + ``` + """ + + def __enter__(self): + def do_nothing(*args, **kwargs): + pass + + def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs): + return self.create_model_and_transforms(*args, pretrained=None, **kwargs) + + def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs): + return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs) + + self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_ + self.init_no_grad_normal = torch.nn.init._no_grad_normal_ + self.create_model_and_transforms = open_clip.create_model_and_transforms + self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained + + torch.nn.init.kaiming_uniform_ = do_nothing + torch.nn.init._no_grad_normal_ = do_nothing + open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained + ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform + torch.nn.init._no_grad_normal_ = self.init_no_grad_normal + open_clip.create_model_and_transforms = self.create_model_and_transforms + ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained + diff --git a/modules/sd_models.py b/modules/sd_models.py index 0a6d55ca..ee241032 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -13,7 +13,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae +from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting @@ -319,7 +319,8 @@ def load_model(checkpoint_info=None): if shared.cmd_opts.no_half: sd_config.model.params.unet_config.params.use_fp16 = False - sd_model = instantiate_from_config(sd_config.model) + with sd_disable_initialization.DisableInitialization(): + sd_model = instantiate_from_config(sd_config.model) load_model_weights(sd_model, checkpoint_info) -- cgit v1.2.3 From ce3f639ec8758ce2bc90483336361d2dc25acd3a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 10 Jan 2023 16:51:04 +0300 Subject: add more stuff to ignore when creating model from config prevent .vae.safetensors files from being listed as stable diffusion models --- modules/modelloader.py | 4 +++- modules/sd_disable_initialization.py | 29 +++++++++++++++++++++++++---- modules/sd_models.py | 32 ++++++++++++++++++++++++++++---- 3 files changed, 56 insertions(+), 9 deletions(-) diff --git a/modules/modelloader.py b/modules/modelloader.py index 6a1a7ac8..e9aa514e 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -10,7 +10,7 @@ from modules.upscaler import Upscaler from modules.paths import script_path, models_path -def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list: +def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list: """ A one-and done loader to try finding the desired models in specified directories. 
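For context before the next hunk: `load_models` gains an `ext_blacklist` argument so that companion files, such as `.vae.safetensors` VAEs sitting beside checkpoints, stop being listed as Stable Diffusion models. The rule is asymmetric: the blacklist matches by filename suffix (so multi-part endings work), while the existing `ext_filter` whitelist compares only the final extension. A standalone sketch of that filtering, with hypothetical names rather than the actual modelloader code:

    import os

    def filter_model_files(paths, ext_filter=None, ext_blacklist=None):
        kept = []
        for path in paths:
            # Blacklist is a suffix match, so ".vae.safetensors" can be excluded
            # even though its final extension ".safetensors" is whitelisted.
            if ext_blacklist and any(path.endswith(x) for x in ext_blacklist):
                continue
            # Whitelist checks only the last extension component.
            if ext_filter and os.path.splitext(path)[1] not in ext_filter:
                continue
            kept.append(path)
        return kept

    # e.g.: filter_model_files(files, ext_filter=['.ckpt', '.safetensors'],
    #                          ext_blacklist=['.vae.safetensors'])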
@@ -45,6 +45,8 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None full_path = file if os.path.isdir(full_path): continue + if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]): + continue if len(ext_filter) != 0: model_name, extension = os.path.splitext(file) if extension not in ext_filter: diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index c9a3b5e4..9942bd7e 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -1,15 +1,19 @@ import ldm.modules.encoders.modules import open_clip import torch +import transformers.utils.hub class DisableInitialization: """ - When an object of this class enters a `with` block, it starts preventing torch's layer initialization - functions from working, and changes CLIP and OpenCLIP to not download model weights. When it leaves, - reverts everything to how it was. + When an object of this class enters a `with` block, it starts: + - preventing torch's layer initialization functions from working + - changes CLIP and OpenCLIP to not download model weights + - changes CLIP to not make requests to check if there is a new version of a file you already have - Use like this: + When it leaves the block, it reverts everything to how it was before. + + Use it like this: ``` with DisableInitialization(): do_things() @@ -26,19 +30,36 @@ class DisableInitialization: def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs): return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs) + def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs): + + # this file is always 404, prevent making request + if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json': + raise transformers.utils.hub.EntryNotFoundError + + try: + return self.transformers_utils_hub_get_from_cache(url, *args, local_files_only=True, **kwargs) + except Exception as e: + return self.transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs) + self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_ self.init_no_grad_normal = torch.nn.init._no_grad_normal_ + self.init_no_grad_uniform_ = torch.nn.init._no_grad_uniform_ self.create_model_and_transforms = open_clip.create_model_and_transforms self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained + self.transformers_utils_hub_get_from_cache = transformers.utils.hub.get_from_cache torch.nn.init.kaiming_uniform_ = do_nothing torch.nn.init._no_grad_normal_ = do_nothing + torch.nn.init._no_grad_uniform_ = do_nothing open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained + transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache def __exit__(self, exc_type, exc_val, exc_tb): torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform torch.nn.init._no_grad_normal_ = self.init_no_grad_normal + torch.nn.init._no_grad_uniform_ = self.init_no_grad_uniform_ open_clip.create_model_and_transforms = self.create_model_and_transforms ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained + transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache diff --git a/modules/sd_models.py 
b/modules/sd_models.py index ee241032..1bb9088b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -2,6 +2,7 @@ import collections import os.path import sys import gc +import time from collections import namedtuple import torch import re @@ -61,7 +62,7 @@ def find_checkpoint_config(info): def list_models(): checkpoints_list.clear() - model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"]) + model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"]) def modeltitle(path, shorthash): abspath = os.path.abspath(path) @@ -288,6 +289,17 @@ def enable_midas_autodownload(): midas.api.load_model = load_model_wrapper +class Timer: + def __init__(self): + self.start = time.time() + + def elapsed(self): + end = time.time() + res = end - self.start + self.start = end + return res + + def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -319,11 +331,17 @@ def load_model(checkpoint_info=None): if shared.cmd_opts.no_half: sd_config.model.params.unet_config.params.use_fp16 = False + timer = Timer() + with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) + elapsed_create = timer.elapsed() + load_model_weights(sd_model, checkpoint_info) + elapsed_load_weights = timer.elapsed() + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) else: @@ -338,7 +356,9 @@ def load_model(checkpoint_info=None): script_callbacks.model_loaded_callback(sd_model) - print("Model loaded.") + elapsed_the_rest = timer.elapsed() + + print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).") return sd_model @@ -349,7 +369,7 @@ def reload_model_weights(sd_model=None, info=None): if not sd_model: sd_model = shared.sd_model - if sd_model is None: # previous model load failed + if sd_model is None: # previous model load failed current_checkpoint_info = None else: current_checkpoint_info = sd_model.sd_checkpoint_info @@ -371,6 +391,8 @@ def reload_model_weights(sd_model=None, info=None): sd_hijack.model_hijack.undo_hijack(sd_model) + timer = Timer() + try: load_model_weights(sd_model, checkpoint_info) except Exception as e: @@ -384,6 +406,8 @@ def reload_model_weights(sd_model=None, info=None): if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) - print("Weights loaded.") + elapsed = timer.elapsed() + + print(f"Weights loaded in {elapsed:.1f}s.") return sd_model -- cgit v1.2.3 From 0f8603a55988d22616b17140e6c4a7e9d0736af5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 10 Jan 2023 17:46:59 +0300 Subject: add support for transformers==4.25.1 add fallback for when quick model creation fails --- modules/sd_disable_initialization.py | 42 ++++++++++++++++++++++++++++++------ modules/sd_models.py | 8 +++++-- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 9942bd7e..088ac24b 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -30,30 +30,53 @@ class DisableInitialization: def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, 
*model_args, **kwargs): return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs) - def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs): + def transformers_modeling_utils_load_pretrained_model(*args, **kwargs): + args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug + return self.transformers_modeling_utils_load_pretrained_model(*args, **kwargs) + + def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs): # this file is always 404, prevent making request if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json': raise transformers.utils.hub.EntryNotFoundError try: - return self.transformers_utils_hub_get_from_cache(url, *args, local_files_only=True, **kwargs) + return original(url, *args, local_files_only=True, **kwargs) except Exception as e: - return self.transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs) + return original(url, *args, local_files_only=False, **kwargs) + + def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs): + return transformers_utils_hub_get_file_from_cache(self.transformers_utils_hub_get_from_cache, url, *args, **kwargs) + + def transformers_tokenization_utils_base_cached_file(url, *args, local_files_only=False, **kwargs): + return transformers_utils_hub_get_file_from_cache(self.transformers_tokenization_utils_base_cached_file, url, *args, **kwargs) + + def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs): + return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs) self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_ self.init_no_grad_normal = torch.nn.init._no_grad_normal_ self.init_no_grad_uniform_ = torch.nn.init._no_grad_uniform_ self.create_model_and_transforms = open_clip.create_model_and_transforms self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained - self.transformers_utils_hub_get_from_cache = transformers.utils.hub.get_from_cache + self.transformers_modeling_utils_load_pretrained_model = getattr(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', None) + self.transformers_tokenization_utils_base_cached_file = getattr(transformers.tokenization_utils_base, 'cached_file', None) + self.transformers_configuration_utils_cached_file = getattr(transformers.configuration_utils, 'cached_file', None) + self.transformers_utils_hub_get_from_cache = getattr(transformers.utils.hub, 'get_from_cache', None) torch.nn.init.kaiming_uniform_ = do_nothing torch.nn.init._no_grad_normal_ = do_nothing torch.nn.init._no_grad_uniform_ = do_nothing open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained - transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache + if self.transformers_modeling_utils_load_pretrained_model is not None: + transformers.modeling_utils.PreTrainedModel._load_pretrained_model = transformers_modeling_utils_load_pretrained_model + if self.transformers_tokenization_utils_base_cached_file is not None: + transformers.tokenization_utils_base.cached_file = transformers_tokenization_utils_base_cached_file + if self.transformers_configuration_utils_cached_file is 
not None: + transformers.configuration_utils.cached_file = transformers_configuration_utils_cached_file + if self.transformers_utils_hub_get_from_cache is not None: + transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache def __exit__(self, exc_type, exc_val, exc_tb): torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform @@ -61,5 +84,12 @@ class DisableInitialization: torch.nn.init._no_grad_uniform_ = self.init_no_grad_uniform_ open_clip.create_model_and_transforms = self.create_model_and_transforms ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained - transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache + if self.transformers_modeling_utils_load_pretrained_model is not None: + transformers.modeling_utils.PreTrainedModel._load_pretrained_model = self.transformers_modeling_utils_load_pretrained_model + if self.transformers_tokenization_utils_base_cached_file is not None: + transformers.utils.hub.cached_file = self.transformers_tokenization_utils_base_cached_file + if self.transformers_configuration_utils_cached_file is not None: + transformers.utils.hub.cached_file = self.transformers_configuration_utils_cached_file + if self.transformers_utils_hub_get_from_cache is not None: + transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache diff --git a/modules/sd_models.py b/modules/sd_models.py index 1bb9088b..b5bc12f0 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization +from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting @@ -333,7 +333,11 @@ def load_model(checkpoint_info=None): timer = Timer() - with sd_disable_initialization.DisableInitialization(): + try: + with sd_disable_initialization.DisableInitialization(): + sd_model = instantiate_from_config(sd_config.model) + except Exception as e: + print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) sd_model = instantiate_from_config(sd_config.model) elapsed_create = timer.elapsed() -- cgit v1.2.3 From e2c8584f753b6fe8116f3032360a3b02e8398349 Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Tue, 10 Jan 2023 22:26:49 +0300 Subject: make VENV envvar accept absolute path instead of relative --- webui.bat | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/webui.bat b/webui.bat index d4d626e2..3a3e310a 100644 --- a/webui.bat +++ b/webui.bat @@ -1,7 +1,7 @@ @echo off if not defined PYTHON (set PYTHON=python) -if not defined VENV_DIR (set VENV_DIR=venv) +if not defined VENV_DIR (set VENV_DIR=%~dp0\venv) set ERROR_REPORTING=FALSE @@ -26,7 +26,7 @@ echo Unable to create venv in directory %VENV_DIR% goto :show_stdout_stderr :activate_venv -set PYTHON="%~dp0%VENV_DIR%\Scripts\Python.exe" +set PYTHON="%VENV_DIR%\Scripts\Python.exe" echo venv %PYTHON% if [%ACCELERATE%] == ["True"] goto :accelerate goto :launch @@ -35,7 +35,7 @@ goto :launch :accelerate echo "Checking for accelerate" -set ACCELERATE="%~dp0%VENV_DIR%\Scripts\accelerate.exe" +set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe" if EXIST %ACCELERATE% goto :accelerate_launch :launch -- 
cgit v1.2.3 From 29fb5327640465fc83111e2170c5d8aa2b15266c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 10 Jan 2023 23:47:02 +0300 Subject: change color selector in settings to be part of form --- modules/shared.py | 4 ++-- modules/ui_components.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index aa37c8ce..264264a6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.interrogate import modules.memmon import modules.styles import modules.devices as devices -from modules import localization, sd_vae, extensions, script_loading, errors +from modules import localization, sd_vae, extensions, script_loading, errors, ui_components from modules.paths import models_path, script_path, sd_path @@ -387,7 +387,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), - "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", gr.ColorPicker, {}), + "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}), "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), diff --git a/modules/ui_components.py b/modules/ui_components.py index cac001dc..97acff06 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -31,3 +31,9 @@ class FormHTML(gr.HTML, gr.components.FormComponent): def get_block_name(self): return "html" + +class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): + """Same as gr.ColorPicker but fits inside gradio forms""" + + def get_block_name(self): + return "colorpicker" -- cgit v1.2.3 From 6be644fa04ce1542f3a01804310cbbc0a4a91620 Mon Sep 17 00:00:00 2001 From: dan Date: Wed, 11 Jan 2023 05:31:58 +0800 Subject: Enable batch_size>1 for mixed-sized training --- modules/textual_inversion/dataset.py | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index fa48708e..b47414f3 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -3,8 +3,10 @@ import numpy as np import PIL import torch from PIL import Image -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import Dataset, DataLoader, Sampler from torchvision import transforms +from collections import defaultdict +from random import shuffle, choices import random import tqdm @@ -45,12 +47,12 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert 
os.listdir(data_root), "Dataset directory is empty" - assert batch_size == 1 or not varsize, 'variable img size must have batch size 1' self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] self.shuffle_tags = shuffle_tags self.tag_drop_out = tag_drop_out + groups = defaultdict(list) print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): @@ -103,13 +105,14 @@ class PersonalizedBase(Dataset): if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags): with devices.autocast(): entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) - + groups[image.size].append(len(self.dataset)) self.dataset.append(entry) del torchdata del latent_dist del latent_sample self.length = len(self.dataset) + self.groups = list(groups.values()) assert self.length > 0, "No images have been found in the dataset." self.batch_size = min(batch_size, self.length) self.gradient_step = min(gradient_step, self.length // self.batch_size) @@ -137,9 +140,34 @@ class PersonalizedBase(Dataset): entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu) return entry +class GroupedBatchSampler(Sampler): + def __init__(self, data_source: PersonalizedBase, batch_size: int): + n = len(data_source) + self.groups = data_source.groups + self.len = n_batch = n // batch_size + expected = [len(g) / n * n_batch * batch_size for g in data_source.groups] + self.base = [int(e) // batch_size for e in expected] + self.n_rand_batches = nrb = n_batch - sum(self.base) + self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected] + self.batch_size = batch_size + def __len__(self): + return self.len + def __iter__(self): + b = self.batch_size + for g in self.groups: + shuffle(g) + batches = [] + for g in self.groups: + batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b)) + for _ in range(self.n_rand_batches): + rand_group = choices(self.groups, self.probs)[0] + batches.append(choices(rand_group, k=b)) + shuffle(batches) + yield from batches + class PersonalizedDataLoader(DataLoader): def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False): - super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory) + super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory) if latent_sampling_method == "random": self.collate_fn = collate_wrapper_random else: -- cgit v1.2.3 From 9cfd10cdefc7b2966b8e42fbb0e05735967cf87b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 01:27:00 +0300 Subject: repair #6612 --- webui.bat | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/webui.bat b/webui.bat index 3a3e310a..e6a7a429 100644 --- a/webui.bat +++ b/webui.bat @@ -1,7 +1,7 @@ @echo off if not defined PYTHON (set PYTHON=python) -if not defined VENV_DIR (set VENV_DIR=%~dp0\venv) +if not defined VENV_DIR (set VENV_DIR=%~dp0%venv) set ERROR_REPORTING=FALSE @@ -13,16 +13,16 @@ echo Couldn't launch python goto :show_stdout_stderr :start_venv -if [%VENV_DIR%] == [-] goto :skip_venv +if ["%VENV_DIR%"] == ["-"] goto :skip_venv -dir %VENV_DIR%\Scripts\Python.exe >tmp/stdout.txt 2>tmp/stderr.txt +dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt if %ERRORLEVEL% == 0 goto :activate_venv for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set 
PYTHON_FULLNAME="%%i" echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME% -%PYTHON_FULLNAME% -m venv %VENV_DIR% >tmp/stdout.txt 2>tmp/stderr.txt +%PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt if %ERRORLEVEL% == 0 goto :activate_venv -echo Unable to create venv in directory %VENV_DIR% +echo Unable to create venv in directory "%VENV_DIR%" goto :show_stdout_stderr :activate_venv -- cgit v1.2.3 From f9706acf431f77e0ce9e4270e5be7299922ee963 Mon Sep 17 00:00:00 2001 From: Lee Bousfield Date: Tue, 10 Jan 2023 18:40:34 -0700 Subject: Support loading textual inversion embeddings from safetensors files --- modules/textual_inversion/textual_inversion.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5420903f..3866c154 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -9,6 +9,7 @@ import tqdm import html import datetime import csv +import safetensors.torch from PIL import Image, PngImagePlugin @@ -150,6 +151,8 @@ class EmbeddingDatabase: name = data.get('name', name) elif ext in ['.BIN', '.PT']: data = torch.load(path, map_location="cpu") + elif ext in ['.SAFETENSORS']: + data = safetensors.torch.load_file(path, device="cpu") else: return -- cgit v1.2.3 From 5830095b73515fc49b3fd567048470005191ec34 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Tue, 10 Jan 2023 21:43:24 -0500 Subject: Add old prompt parser compat option --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/shared.py b/modules/shared.py index 264264a6..b61bbd3f 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -400,6 +400,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), + "use_old_prompt_parser_default_step_transformer": OptionInfo(False, "Use old prompt parser default step transformer. In particular, alternating words that contained emphasis were not parsed correctly. 
Useful to reproduce old seeds."), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { -- cgit v1.2.3 From 7e45fba55b24166501033a221e6268545fa47fbe Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Tue, 10 Jan 2023 21:47:03 -0500 Subject: Fix prompt parser default step transformer w/ test --- modules/prompt_parser.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index f70872c4..b69f1425 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -3,6 +3,11 @@ from collections import namedtuple from typing import List import lark +try: + from modules.shared import opts +except: + pass + # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" # will be represented with prompt_schedule like this (assuming steps=100): # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy'] @@ -49,6 +54,8 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): [[5, 'a c'], [10, 'a {b|d{ c']] >>> g("((a][:b:c [d:3]") [[3, '((a][:b:c '], [10, '((a][:b:c d']] + >>> g("[a|(b:1.1)]") + [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']] """ def collect_steps(steps, tree): @@ -84,7 +91,13 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): yield args[0].value def __default__(self, data, children, meta): for child in children: - yield from child + try: + if opts.use_old_prompt_parser_default_step_transformer: + yield from child + else: + yield child + except: + yield child return AtStep().transform(tree) def get_schedule(prompt): -- cgit v1.2.3 From 37a230112198adcb3f24d59b399cff342a6d479e Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 10 Jan 2023 20:30:09 -0800 Subject: Expose the compiled class module of scripts to extensions --- modules/scripts.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/scripts.py b/modules/scripts.py index 35164093..4ffc369b 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -152,7 +152,7 @@ def basedir(): scripts_data = [] ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"]) -ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"]) +ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) def list_scripts(scriptdirname, extension): @@ -206,7 +206,7 @@ def load_scripts(): for key, script_class in module.__dict__.items(): if type(script_class) == type and issubclass(script_class, Script): - scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir)) + scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) except Exception: print(f"Error loading script: {scriptfile.filename}", file=sys.stderr) @@ -241,7 +241,7 @@ class ScriptRunner: self.alwayson_scripts.clear() self.selectable_scripts.clear() - for script_class, path, basedir in scripts_data: + for script_class, path, basedir, script_module in scripts_data: script = script_class() script.filename = path script.is_txt2img = not is_img2img -- cgit v1.2.3 From 954091697fce7a1b7997d5f3d73551f793f6bebc Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 
09:10:07 +0300 Subject: add an option to copy config from one of models in checkpoint merger --- modules/extras.py | 30 +++++++++++++++++++++++++++++- modules/ui.py | 9 ++++++--- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index 7407bfe3..a03d558e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -3,6 +3,7 @@ import math import os import sys import traceback +import shutil import numpy as np from PIL import Image @@ -248,7 +249,32 @@ def run_pnginfo(image): return '', geninfo, info -def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format): +def create_config(ckpt_result, config_source, a, b, c): + def config(x): + return sd_models.find_checkpoint_config(x) if x else None + + if config_source == 0: + cfg = config(a) or config(b) or config(c) + elif config_source == 1: + cfg = config(b) + elif config_source == 2: + cfg = config(c) + else: + cfg = None + + if cfg is None: + return + + filename, _ = os.path.splitext(ckpt_result) + checkpoint_filename = filename + ".yaml" + + print("Copying config:") + print(" from:", cfg) + print(" to:", checkpoint_filename) + shutil.copyfile(cfg, checkpoint_filename) + + +def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source): shared.state.begin() shared.state.job = 'model-merge' @@ -356,6 +382,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam sd_models.list_models() + create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) + print("Checkpoint saved.") shared.state.textinfo = "Checkpoint saved to " + output_modelname shared.state.end() diff --git a/modules/ui.py b/modules/ui.py index 3c458ce8..82f5dd7c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1129,7 +1129,7 @@ def create_ui(): with gr.Column(variant='panel'): gr.HTML(value="
A merger of the two checkpoints will be generated in your checkpoint directory.
") - with gr.Row(): + with FormRow(): primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") @@ -1143,11 +1143,13 @@ def create_ui(): interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") - with gr.Row(): + with FormRow(): checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") - modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') + config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method") + + modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') with gr.Column(variant='panel'): submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) @@ -1703,6 +1705,7 @@ def create_ui(): save_as_half, custom_name, checkpoint_format, + config_source, ], outputs=[ submit_result, -- cgit v1.2.3 From 4fdacd31e48c6a7a35c1c25c559932585e8addde Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 10:24:56 +0300 Subject: possible fix for fallback for fast model creation from config --- modules/sd_models.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index b5bc12f0..a0a8a909 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -337,6 +337,9 @@ def load_model(checkpoint_info=None): with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) except Exception as e: + pass + + if sd_model is None: print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) sd_model = instantiate_from_config(sd_config.model) -- cgit v1.2.3 From 1a23dc32ac5e16fac10115cafd0b841abd06e59f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 10:34:36 +0300 Subject: possible fix for fallback for fast model creation from config, attempt 2 --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/sd_models.py b/modules/sd_models.py index a0a8a909..084ba7fa 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -333,6 +333,7 @@ def load_model(checkpoint_info=None): timer = Timer() + sd_model = None try: with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) -- cgit v1.2.3 From b202714b65aa2145ff965ed4f197ac1516093f34 Mon Sep 17 00:00:00 2001 From: Alexey Shirokov <40300551+demiurge-ash@users.noreply.github.com> Date: Wed, 11 Jan 2023 11:41:50 +0300 Subject: Fix keyboard navigation in modal image viewer --- javascript/imageviewer.js | 1 + 1 file changed, 1 insertion(+) diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index b7bc2fe1..1f29ad7b 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -151,6 +151,7 @@ function showGalleryImage() { 
e.addEventListener('mousedown', function (evt) { if(!opts.js_modal_lightbox || evt.button != 0) return; modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed) + evt.preventDefault() showModal(evt) }, true); } -- cgit v1.2.3 From ab388d6f8bf51338de1950b3907c324b0ff6a872 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Wed, 11 Jan 2023 08:59:47 -0500 Subject: Remove compat option check for prompt parser --- modules/prompt_parser.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index b69f1425..870218db 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -3,11 +3,6 @@ from collections import namedtuple from typing import List import lark -try: - from modules.shared import opts -except: - pass - # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" # will be represented with prompt_schedule like this (assuming steps=100): # [25, 'fantasy landscape with a mountain and an oak in foreground shoddy'] @@ -91,13 +86,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): yield args[0].value def __default__(self, data, children, meta): for child in children: - try: - if opts.use_old_prompt_parser_default_step_transformer: - yield from child - else: - yield child - except: - yield child + yield child return AtStep().transform(tree) def get_schedule(prompt): -- cgit v1.2.3 From 0b38b72d31ead82c7d0998a29e50da90073831f7 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Wed, 11 Jan 2023 09:01:37 -0500 Subject: Remove compat option for prompt parser --- modules/shared.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index b61bbd3f..264264a6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -400,7 +400,6 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), - "use_old_prompt_parser_default_step_transformer": OptionInfo(False, "Use old prompt parser default step transformer. In particular, alternating words that contained emphasis were not parsed correctly. 
Useful to reproduce old seeds."), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { -- cgit v1.2.3 From 39ea251945d70efcf9b59d44eb0e71269d754aa4 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Wed, 11 Jan 2023 10:23:51 -0500 Subject: add textinfo to progress response --- modules/api/api.py | 4 ++-- modules/api/models.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 6c564ad8..5767ba90 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -286,7 +286,7 @@ class Api: # copy from check_progress_call of ui.py if shared.state.job_count == 0: - return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict()) + return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict(), textinfo=shared.state.textinfo) # avoid dividing zero progress = 0.01 @@ -308,7 +308,7 @@ class Api: if shared.state.current_image and not req.skip_current_image: current_image = encode_pil_to_base64(shared.state.current_image) - return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image) + return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image, textinfo=shared.state.textinfo) def interrogateapi(self, interrogatereq: InterrogateRequest): image_b64 = interrogatereq.image diff --git a/modules/api/models.py b/modules/api/models.py index 034b4aa0..c78095ca 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -168,6 +168,7 @@ class ProgressResponse(BaseModel): eta_relative: float = Field(title="ETA in secs") state: dict = Field(title="State", description="The current state snapshot") current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.") + textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.") class InterrogateRequest(BaseModel): image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.") -- cgit v1.2.3 From 3f43d8a966ba8462ba019a5ad573f94508cd45f8 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Wed, 11 Jan 2023 10:28:55 -0500 Subject: set descriptions --- modules/hypernetworks/hypernetwork.py | 4 +++- modules/textual_inversion/preprocess.py | 7 ++++++- modules/textual_inversion/textual_inversion.py | 4 +++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 300d3975..194679e8 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -619,7 +619,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, epoch_num = hypernetwork.step // steps_per_epoch epoch_step = hypernetwork.step % steps_per_epoch - pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}") + description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}" + pbar.set_description(description) + shared.state.textinfo = description if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: # Before saving, change name to match current checkpoint. 
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index feb876c6..3c1042ad 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -135,7 +135,8 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre params.process_caption_deepbooru = process_caption_deepbooru params.preprocess_txt_action = preprocess_txt_action - for index, imagefile in enumerate(tqdm.tqdm(files)): + pbar = tqdm.tqdm(files) + for index, imagefile in enumerate(pbar): params.subindex = 0 filename = os.path.join(src, imagefile) try: @@ -143,6 +144,10 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre except Exception: continue + description = f"Preprocessing [Image {index}/{len(files)}]" + pbar.set_description(description) + shared.state.textinfo = description + params.src = filename existing_caption = None diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 3866c154..b915b091 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -476,7 +476,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ epoch_num = embedding.step // steps_per_epoch epoch_step = embedding.step % steps_per_epoch - pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}") + description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}" + pbar.set_description(description) + shared.state.textinfo = description if embedding_dir is not None and steps_done % save_embedding_every == 0: # Before saving, change name to match current checkpoint. 
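# [Illustrative sketch, not from the original patch.] The three call sites this
# commit touches (hypernetwork training, image preprocessing, textual inversion
# training) repeat the same two lines, so they could be folded into one small
# helper; the helper name below is hypothetical:
#
#     def report_progress(pbar, description):
#         # show the text on the console progress bar and mirror it into
#         # shared.state so the web UI and the progress API can display it too
#         pbar.set_description(description)
#         shared.state.textinfo = description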
embedding_name_every = f'{embedding_name}-{steps_done}' -- cgit v1.2.3 From 4bd490727e156ff53107d53416d6b89be86f2a62 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 18:54:04 +0300 Subject: fix for an error caused by skipping initialization, for realsies this time: TypeError: expected str, bytes or os.PathLike object, not NoneType --- modules/sd_disable_initialization.py | 71 ++++++++++++++++-------------------- modules/sd_models.py | 1 + 2 files changed, 33 insertions(+), 39 deletions(-) diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 088ac24b..c72d8efc 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -20,6 +20,19 @@ class DisableInitialization: ``` """ + def __init__(self): + self.replaced = [] + + def replace(self, obj, field, func): + original = getattr(obj, field, None) + if original is None: + return None + + self.replaced.append((obj, field, original)) + setattr(obj, field, func) + + return original + def __enter__(self): def do_nothing(*args, **kwargs): pass @@ -37,11 +50,14 @@ class DisableInitialization: def transformers_utils_hub_get_file_from_cache(original, url, *args, **kwargs): # this file is always 404, prevent making request - if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json': - raise transformers.utils.hub.EntryNotFoundError + if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json' or url == 'openai/clip-vit-large-patch14' and args[0] == 'added_tokens.json': + return None try: - return original(url, *args, local_files_only=True, **kwargs) + res = original(url, *args, local_files_only=True, **kwargs) + if res is None: + res = original(url, *args, local_files_only=False, **kwargs) + return res except Exception as e: return original(url, *args, local_files_only=False, **kwargs) @@ -54,42 +70,19 @@ class DisableInitialization: def transformers_configuration_utils_cached_file(url, *args, local_files_only=False, **kwargs): return transformers_utils_hub_get_file_from_cache(self.transformers_configuration_utils_cached_file, url, *args, **kwargs) - self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_ - self.init_no_grad_normal = torch.nn.init._no_grad_normal_ - self.init_no_grad_uniform_ = torch.nn.init._no_grad_uniform_ - self.create_model_and_transforms = open_clip.create_model_and_transforms - self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained - self.transformers_modeling_utils_load_pretrained_model = getattr(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', None) - self.transformers_tokenization_utils_base_cached_file = getattr(transformers.tokenization_utils_base, 'cached_file', None) - self.transformers_configuration_utils_cached_file = getattr(transformers.configuration_utils, 'cached_file', None) - self.transformers_utils_hub_get_from_cache = getattr(transformers.utils.hub, 'get_from_cache', None) - - torch.nn.init.kaiming_uniform_ = do_nothing - torch.nn.init._no_grad_normal_ = do_nothing - torch.nn.init._no_grad_uniform_ = do_nothing - open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained - ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained - if self.transformers_modeling_utils_load_pretrained_model is not None: - transformers.modeling_utils.PreTrainedModel._load_pretrained_model = 
transformers_modeling_utils_load_pretrained_model - if self.transformers_tokenization_utils_base_cached_file is not None: - transformers.tokenization_utils_base.cached_file = transformers_tokenization_utils_base_cached_file - if self.transformers_configuration_utils_cached_file is not None: - transformers.configuration_utils.cached_file = transformers_configuration_utils_cached_file - if self.transformers_utils_hub_get_from_cache is not None: - transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache + self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing) + self.replace(torch.nn.init, '_no_grad_normal_', do_nothing) + self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing) + self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained) + self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained) + self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model) + self.transformers_tokenization_utils_base_cached_file = self.replace(transformers.tokenization_utils_base, 'cached_file', transformers_tokenization_utils_base_cached_file) + self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file) + self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) def __exit__(self, exc_type, exc_val, exc_tb): - torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform - torch.nn.init._no_grad_normal_ = self.init_no_grad_normal - torch.nn.init._no_grad_uniform_ = self.init_no_grad_uniform_ - open_clip.create_model_and_transforms = self.create_model_and_transforms - ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained - if self.transformers_modeling_utils_load_pretrained_model is not None: - transformers.modeling_utils.PreTrainedModel._load_pretrained_model = self.transformers_modeling_utils_load_pretrained_model - if self.transformers_tokenization_utils_base_cached_file is not None: - transformers.utils.hub.cached_file = self.transformers_tokenization_utils_base_cached_file - if self.transformers_configuration_utils_cached_file is not None: - transformers.utils.hub.cached_file = self.transformers_configuration_utils_cached_file - if self.transformers_utils_hub_get_from_cache is not None: - transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache + for obj, field, original in self.replaced: + setattr(obj, field, original) + + self.replaced.clear() diff --git a/modules/sd_models.py b/modules/sd_models.py index 084ba7fa..c466f273 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -334,6 +334,7 @@ def load_model(checkpoint_info=None): timer = Timer() sd_model = None + try: with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) -- cgit v1.2.3 From 0b8911d883118daa54f7735c5b753b5575d9f943 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 20:33:24 +0300 Subject: img2img UI rework: obsolete --gradio-img2img-tool --gradio-inpaint-tool and always show all tools each in own tab --- modules/img2img.py | 58 ++++++++++++++---------------- modules/shared.py 
| 4 +-- modules/ui.py | 103 +++++++++++++++++++++++++++-------------------------- style.css | 4 ++- 4 files changed, 84 insertions(+), 85 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index ca58b5d8..f62783c6 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -59,38 +59,34 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_with_mask_orig, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): - is_inpaint = mode == 1 - is_batch = mode == 2 - - if is_inpaint: - # Drawn mask - if mask_mode == 0: - is_mask_sketch = isinstance(init_img_with_mask, dict) - is_mask_paint = not is_mask_sketch - if is_mask_sketch: - # Sketch: mask iff. not transparent - image, mask = init_img_with_mask["image"], init_img_with_mask["mask"] - alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') - mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L') - else: - # Color-sketch: mask iff. painted over - image = init_img_with_mask - orig = init_img_with_mask_orig or init_img_with_mask - pred = np.any(np.array(image) != np.array(orig), axis=-1) - mask = Image.fromarray(pred.astype(np.uint8) * 255, "L") - mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100) - blur = ImageFilter.GaussianBlur(mask_blur) - image = Image.composite(image.filter(blur), orig, mask.filter(blur)) - - image = image.convert("RGB") - # Uploaded mask - else: - image = init_img_inpaint - mask = init_mask_inpaint - # No mask +def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): + is_batch = mode == 5 + + if mode == 0: # img2img + image = init_img.convert("RGB") + mask = None + elif mode == 1: # img2img sketch + image = sketch.convert("RGB") + mask = None + elif mode == 2: # inpaint + image, mask = init_img_with_mask["image"], init_img_with_mask["mask"] + alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') + mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L') + image = image.convert("RGB") + elif mode == 3: # inpaint sketch + image = inpaint_color_sketch + orig = inpaint_color_sketch_orig or 
inpaint_color_sketch + pred = np.any(np.array(image) != np.array(orig), axis=-1) + mask = Image.fromarray(pred.astype(np.uint8) * 255, "L") + mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100) + blur = ImageFilter.GaussianBlur(mask_blur) + image = Image.composite(image.filter(blur), orig, mask.filter(blur)) + image = image.convert("RGB") + elif mode == 4: # inpaint upload mask + image = init_img_inpaint + mask = init_mask_inpaint else: - image = init_img + image = None mask = None # Use the EXIF orientation of photos taken by smartphones. diff --git a/modules/shared.py b/modules/shared.py index 264264a6..1c964237 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -74,8 +74,8 @@ parser.add_argument("--freeze-settings", action='store_true', help="disable edit parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) -parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor") -parser.add_argument("--gradio-inpaint-tool", type=str, choices=["sketch", "color-sketch"], default="sketch", help="gradio inpainting editor: can be either sketch to only blur/noise the input, or color-sketch to paint over it") +parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything') +parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) diff --git a/modules/ui.py b/modules/ui.py index 82f5dd7c..e86a624b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -795,53 +795,67 @@ def create_ui(): with FormRow().style(equal_height=False): with gr.Column(variant='panel', elem_id="img2img_settings"): + with gr.Tabs(elem_id="mode_img2img"): + with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480) - with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: - with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab"): - init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool, image_mode="RGBA").style(height=480) + with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: + sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480) - with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab"): - init_img_with_mask = 
gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_inpaint_tool, image_mode="RGBA").style(height=480) - init_img_with_mask_orig = gr.State(None) + with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480) - use_color_sketch = cmd_opts.gradio_inpaint_tool == "color-sketch" - if use_color_sketch: - def update_orig(image, state): - if image is not None: - same_size = state is not None and state.size == image.size - has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) - edited = same_size and has_exact_match - return image if not edited or state is None else state + with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: + inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480) + inpaint_color_sketch_orig = gr.State(None) - init_img_with_mask.change(update_orig, [init_img_with_mask, init_img_with_mask_orig], init_img_with_mask_orig) + def update_orig(image, state): + if image is not None: + same_size = state is not None and state.size == image.size + has_exact_match = np.any(np.all(np.array(image) == np.array(state), axis=-1)) + edited = same_size and has_exact_match + return image if not edited or state is None else state - init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_base") - init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False, elem_id="img_inpaint_mask") + inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) - with FormRow(): - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") - mask_alpha = gr.Slider(label="Mask transparency", interactive=use_color_sketch, visible=use_color_sketch, elem_id="img2img_mask_alpha") - - with FormRow(): - mask_mode = gr.Radio(label="Mask source", choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") - inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - - with FormRow(): - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") - - with FormRow(): - with gr.Column(): - inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: + init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") + init_mask_inpaint = gr.Image(label="Mask", source="upload", interactive=True, type="pil", elem_id="img_inpaint_mask") - with 
gr.Column(scale=4): - inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - - with gr.TabItem('Batch img2img', id='batch', elem_id="img2img_batch_tab"): + with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: hidden = '
<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' gr.HTML(f"<p class=\"text-gray-500\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>
") img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: + with FormRow(): + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") + mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") + + with FormRow(): + inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") + + with FormRow(): + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") + + with FormRow(): + with gr.Column(): + inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") + + with gr.Column(scale=4): + inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") + + def select_img2img_tab(tab): + return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), + + for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]): + elem.select( + fn=lambda tab=i: select_img2img_tab(tab), + inputs=[], + outputs=[inpaint_controls, mask_alpha], + ) + with FormRow(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") @@ -900,20 +914,6 @@ def create_ui(): ] ) - mask_mode.change( - lambda mode, img: { - init_img_with_mask: gr_show(mode == 0), - init_img_inpaint: gr_show(mode == 1), - init_mask_inpaint: gr_show(mode == 1), - }, - inputs=[mask_mode, init_img_with_mask], - outputs=[ - init_img_with_mask, - init_img_inpaint, - init_mask_inpaint, - ], - ) - img2img_args = dict( fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), _js="submit_img2img", @@ -924,11 +924,12 @@ def create_ui(): img2img_prompt_style, img2img_prompt_style2, init_img, + sketch, init_img_with_mask, - init_img_with_mask_orig, + inpaint_color_sketch, + inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, - mask_mode, steps, sampler_index, mask_blur, diff --git a/style.css b/style.css index ec5e4182..ffd6307f 100644 --- a/style.css +++ b/style.css @@ -557,7 +557,9 @@ canvas[key="mask"] { } #img2img_image, #img2img_image > .h-60, #img2img_image > .h-60 > div, #img2img_image > .h-60 > div > img, -img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img +#img2img_sketch, #img2img_sketch > .h-60, #img2img_sketch > .h-60 > div, #img2img_sketch > .h-60 > div > img, +#img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h-60 > div > img, +#inpaint_sketch, #inpaint_sketch > .h-60, #inpaint_sketch > .h-60 > div, #inpaint_sketch > .h-60 > div > img { height: 480px !important; max-height: 480px !important; -- cgit v1.2.3 From d52a80f7f7da160c73afd067c8f1bf491391f994 Mon Sep 17 00:00:00 2001 From: Shondoit Date: Thu, 12 Jan 2023 09:22:29 +0100 Subject: Allow creation of zero vectors 
for TI --- modules/textual_inversion/textual_inversion.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index b915b091..853246a6 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -248,11 +248,14 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): with devices.autocast(): cond_model([""]) # will send cond model to GPU if lowvram/medvram is active - embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token) + #cond_model expects at least some text, so we provide '*' as backup. + embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token) vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) - for i in range(num_vectors_per_token): - vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] + #Only copy if we provided an init_text, otherwise keep vectors as zeros + if init_text: + for i in range(num_vectors_per_token): + vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] # Remove illegal characters from name. name = "".join( x for x in name if (x.isalnum() or x in "._- ")) -- cgit v1.2.3 From d48dcbd2b29eab492d53d78f482356d78e5beb19 Mon Sep 17 00:00:00 2001 From: Shondoit Date: Thu, 12 Jan 2023 09:53:35 +0100 Subject: Add zero vector feature to hints.js Also add the note that some tokens might be skipped. Not everyone is aware of this. --- javascript/hints.js | 1 + 1 file changed, 1 insertion(+) diff --git a/javascript/hints.js b/javascript/hints.js index 856e1389..244bfde2 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -92,6 +92,7 @@ titles = { "Weighted sum": "Result = A * (1 - M) + B * M", "Add difference": "Result = A + (B - C) * M", + "Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors", "Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. 
If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.", "Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.", -- cgit v1.2.3 From 5623a3e7b1beed61f3ae6829a05b7b861d70e203 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 12 Jan 2023 19:47:33 +0300 Subject: fix send to inpaint sending you to wrong place --- javascript/ui.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/ui.js b/javascript/ui.js index ee226927..a41dd26f 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -54,7 +54,7 @@ function switch_to_img2img(){ function switch_to_inpaint(){ gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click(); - gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click(); + gradioApp().getElementById('mode_img2img').querySelectorAll('button')[2].click(); return args_to_array(arguments); } -- cgit v1.2.3 From 6ffefdcc9f47b66cbc543690d97cbf8327f4ba58 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 12 Jan 2023 19:47:44 +0300 Subject: fix js errors when restarting UI --- script.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/script.js b/script.js index 0e117d06..21960d91 100644 --- a/script.js +++ b/script.js @@ -1,5 +1,6 @@ function gradioApp() { - const gradioShadowRoot = document.getElementsByTagName('gradio-app')[0].shadowRoot + const elems = document.getElementsByTagName('gradio-app') + const gradioShadowRoot = elems.length == 0 ? null : elems[0].shadowRoot return !!gradioShadowRoot ? 
gradioShadowRoot : document; } -- cgit v1.2.3 From 88416ab5ff787eec3b9962b43b5e544bb75fbad6 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 12 Jan 2023 13:46:59 -0800 Subject: Fix extension parameters not being saved to last used parameters --- modules/processing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index f04a0e1e..ae04cab7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -531,16 +531,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: def infotext(iteration=0, position_in_batch=0): return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch) - with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: - processed = Processed(p, [], p.seed, "") - file.write(processed.infotext(p, 0)) - if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: model_hijack.embedding_db.load_textual_inversion_embeddings() if p.scripts is not None: p.scripts.process(p) + with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: + processed = Processed(p, [], p.seed, "") + file.write(processed.infotext(p, 0)) + infotexts = [] output_images = [] -- cgit v1.2.3 From 6c88eaed4f5efca54a882eb1f8f30f01f350332a Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 12 Jan 2023 13:50:09 -0800 Subject: Add script callback for fixing infotext parameters --- modules/generation_parameters_copypaste.py | 3 ++- modules/script_callbacks.py | 20 +++++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 620aa606..593d99ef 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -7,7 +7,7 @@ from pathlib import Path import gradio as gr from modules.shared import script_path -from modules import shared, ui_tempdir +from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image @@ -298,6 +298,7 @@ def connect_paste(button, paste_fields, input_comp, jsfunc=None): prompt = file.read() params = parse_generation_parameters(prompt) + script_callbacks.infotext_pasted_callback(prompt, params) res = [] for output, key in paste_fields: diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 608c5300..a9e19236 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -2,7 +2,7 @@ import sys import traceback from collections import namedtuple import inspect -from typing import Optional +from typing import Optional, Dict, Any from fastapi import FastAPI from gradio import Blocks @@ -71,6 +71,7 @@ callback_map = dict( callbacks_before_component=[], callbacks_after_component=[], callbacks_image_grid=[], + callbacks_infotext_pasted=[], callbacks_script_unloaded=[], ) @@ -172,6 +173,14 @@ def image_grid_callback(params: ImageGridLoopParams): report_exception(c, 'image_grid') +def infotext_pasted_callback(infotext: str, params: Dict[str, Any]): + for c in callback_map['callbacks_infotext_pasted']: + try: + c.callback(infotext, params) + except Exception: + report_exception(c, 'infotext_pasted') + + def script_unloaded_callback(): for c in reversed(callback_map['callbacks_script_unloaded']): try: @@ -290,6 +299,15 @@ def on_image_grid(callback): 
add_callback(callback_map['callbacks_image_grid'], callback) +def on_infotext_pasted(callback): + """register a function to be called before applying an infotext. + The callback is called with two arguments: + - infotext: str - raw infotext. + - result: Dict[str, any] - parsed infotext parameters. + """ + add_callback(callback_map['callbacks_infotext_pasted'], callback) + + def on_script_unloaded(callback): """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that the script did should be reverted here""" -- cgit v1.2.3 From 0b262802b86a55c4f71faf377f2cb1aee2960b63 Mon Sep 17 00:00:00 2001 From: Josh R Date: Thu, 12 Jan 2023 17:31:05 -0800 Subject: add gradient settings to training settings log files --- modules/textual_inversion/logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py index 8b1981d5..31e50b64 100644 --- a/modules/textual_inversion/logging.py +++ b/modules/textual_inversion/logging.py @@ -2,7 +2,7 @@ import datetime import json import os -saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"} +saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"} saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"} saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"} saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet -- cgit v1.2.3 From a176d89487d92f5a5b152401e5c424b34ff43b96 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 13 Jan 2023 14:32:15 +0300 Subject: print bucket sizes for training without resizing images #6620 fix an error when generating a picture with embedding in it --- modules/textual_inversion/dataset.py | 16 ++++++++++++++++ modules/textual_inversion/image_embedding.py | 4 ++-- modules/textual_inversion/textual_inversion.py | 2 +- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index b47414f3..d31963d4 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -118,6 +118,12 @@ class PersonalizedBase(Dataset): self.gradient_step = min(gradient_step, self.length // self.batch_size) self.latent_sampling_method = latent_sampling_method + if len(groups) > 1: + print("Buckets:") + for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]): + print(f" {w}x{h}: {len(ids)}") + print() + def create_text(self, filename_text): text = random.choice(self.lines) tags = filename_text.split(',') @@ -140,8 +146,11 @@ class PersonalizedBase(Dataset): entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu) return entry + class GroupedBatchSampler(Sampler): def __init__(self, data_source: PersonalizedBase, batch_size: int): + super().__init__(data_source) + n = len(data_source) self.groups = data_source.groups self.len = n_batch = n // batch_size @@ 
-150,21 +159,28 @@ class GroupedBatchSampler(Sampler): self.n_rand_batches = nrb = n_batch - sum(self.base) self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected] self.batch_size = batch_size + def __len__(self): return self.len + def __iter__(self): b = self.batch_size + for g in self.groups: shuffle(g) + batches = [] for g in self.groups: batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b)) for _ in range(self.n_rand_batches): rand_group = choices(self.groups, self.probs)[0] batches.append(choices(rand_group, k=b)) + shuffle(batches) + yield from batches + class PersonalizedDataLoader(DataLoader): def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False): super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory) diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index ea653806..5593f88c 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -76,10 +76,10 @@ def insert_image_data_embed(image, data): next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h)) next_size = next_size + ((h*d)-(next_size % (h*d))) - data_np_low.resize(next_size) + data_np_low = np.resize(data_np_low, next_size) data_np_low = data_np_low.reshape((h, -1, d)) - data_np_high.resize(next_size) + data_np_high = np.resize(data_np_high, next_size) data_np_high = data_np_high.reshape((h, -1, d)) edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024] diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 853246a6..e23906ca 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -479,7 +479,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ epoch_num = embedding.step // steps_per_epoch epoch_step = embedding.step % steps_per_epoch - description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}" + description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}" pbar.set_description(description) shared.state.textinfo = description if embedding_dir is not None and steps_done % save_embedding_every == 0: -- cgit v1.2.3 From 82725f0ac439f7e3b67858d55900e95330bbd326 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 13 Jan 2023 15:04:37 +0300 Subject: fix a bug caused by merge --- modules/textual_inversion/textual_inversion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 85210b0e..6939efcc 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -11,6 +11,7 @@ import datetime import csv import safetensors.torch +import numpy as np from PIL import Image, PngImagePlugin from torch.utils.tensorboard import SummaryWriter -- cgit v1.2.3 From d753a9df952ea640acbce724e8153356c8b68424 Mon Sep 17 00:00:00 2001 From: Zaprudin Aleksey Date: Fri, 13 Jan 2023 22:25:33 +0500 Subject: fix progress bar behavior for "Prompts from file or textbox" script --- scripts/prompts_from_file.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/prompts_from_file.py 
b/scripts/prompts_from_file.py index 2751f98a..1fe10a7c 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -146,7 +146,7 @@ class Script(scripts.Script): else: args = {"prompt": line} - n_iter = args.get("n_iter", 1) + n_iter = args.get("n_iter", p.n_iter) if n_iter != 1: job_count += n_iter else: -- cgit v1.2.3 From cbf4b3472b1da35937ff12c06072214a2e5cbad7 Mon Sep 17 00:00:00 2001 From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Date: Fri, 13 Jan 2023 19:18:56 +0100 Subject: Automatic launch argument for AMD GPUs This commit adds a few lines to detect if the system has an AMD gpu and adds an environment variable needed for torch to recognize the gpu. --- webui.sh | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/webui.sh b/webui.sh index c4d6521d..23629ef9 100755 --- a/webui.sh +++ b/webui.sh @@ -165,5 +165,11 @@ else printf "\n%s\n" "${delimiter}" printf "Launching launch.py..." printf "\n%s\n" "${delimiter}" - exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" + gpu_info=$(lspci | grep VGA) + if echo "$gpu_info" | grep -q "AMD" + then + HSA_OVERRIDE_GFX_VERSION=10.3.0 exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" + else + exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" + fi fi -- cgit v1.2.3 From eaebcf638391071172d504568d661931f7e3c740 Mon Sep 17 00:00:00 2001 From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Date: Fri, 13 Jan 2023 19:20:18 +0100 Subject: GPU detection script This commit adds a script that detects which GPU is currently used in Windows and Linux --- detection.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 detection.py diff --git a/detection.py b/detection.py new file mode 100644 index 00000000..eb4db0df --- /dev/null +++ b/detection.py @@ -0,0 +1,45 @@ +# This script detects which GPU is currently used in Windows and Linux +import os +import sys + +def check_gpu(): + # First, check if the `lspci` command is available + if not os.system("which lspci > /dev/null") == 0: + # If the `lspci` command is not available, try the `dxdiag` command on Windows + if os.name == "nt": + # On Windows, run the `dxdiag` command and check the output for the "Card name" field + # Create the dxdiag.txt file + os.system("dxdiag /t dxdiag.txt") + + # Read the dxdiag.txt file + with open("dxdiag.txt", "r") as f: + output = f.read() + + if "Card name" in output: + card_name_start = output.index("Card name: ") + len("Card name: ") + card_name_end = output.index("\n", card_name_start) + card_name = output[card_name_start:card_name_end] + else: + card_name = "Unknown" + print(f"Card name: {card_name}") + os.remove("dxdiag.txt") + if "AMD" in card_name: + return "AMD" + elif "Intel" in card_name: + return "Intel" + elif "NVIDIA" in card_name: + return "NVIDIA" + else: + return "Unknown" + else: + # If the `lspci` command is available, use it to get the GPU vendor and model information + output = os.popen("lspci | grep -i vga").read() + if "AMD" in output: + return "AMD" + elif "Intel" in output: + return "Intel" + elif "NVIDIA" in output: + return "NVIDIA" + else: + return "Unknown" + -- cgit v1.2.3 From a407c9f0147c779865c940cbf62c7019dbc1f7b4 Mon Sep 17 00:00:00 2001 From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Date: Fri, 13 Jan 2023 19:22:23 +0100 Subject: Automatic torch install for amd on linux This commit allows the launch script to automatically download rocm's torch version for AMD GPUs using an external GPU detection script. 
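Illustrative aside, not from the original patch: the detection helper introduced above exposes a single entry point, check_gpu(), which returns one of "AMD", "Intel", "NVIDIA" or "Unknown". A minimal standalone check, assuming detection.py is importable from the working directory, might be:

    import detection

    vendor = detection.check_gpu()  # e.g. "NVIDIA"
    print(f"Detected GPU vendor: {vendor}")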
It also prints the operative system and GPU in use. --- launch.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/launch.py b/launch.py index bcbb792c..668548f1 100644 --- a/launch.py +++ b/launch.py @@ -7,6 +7,7 @@ import shlex import platform import argparse import json +import detection dir_repos = "repositories" dir_extensions = "extensions" @@ -15,6 +16,12 @@ git = os.environ.get('GIT', "git") index_url = os.environ.get('INDEX_URL', "") stored_commit_hash = None +# Get the GPU vendor and the operating system +gpu = detection.check_gpu() +if os.name == "posix": + os_name = platform.uname().system +else: + os_name = os.name def commit_hash(): global stored_commit_hash @@ -173,7 +180,11 @@ def run_extensions_installers(settings_file): def prepare_environment(): - torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") + if gpu == "AMD" and os_name !="nt": + torch_command = os.environ.get('TORCH_COMMAND', "pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2") + else: + torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") + requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") commandline_args = os.environ.get('COMMANDLINE_ARGS', "") @@ -295,6 +306,8 @@ def tests(test_dir): def start(): + print(f"Operating System: {os_name}") + print(f"GPU: {gpu}") print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}") import webui if '--nowebui' in sys.argv: -- cgit v1.2.3 From a95f1353089bdeaccd7c266b40cdd79efedfe632 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 09:56:59 +0300 Subject: change hash to sha256 --- .gitignore | 1 + modules/api/api.py | 2 +- modules/api/models.py | 3 +- modules/hashes.py | 72 +++++++++++++++ modules/hypernetworks/hypernetwork.py | 4 +- modules/sd_models.py | 116 ++++++++++++++++--------- modules/shared.py | 2 +- modules/textual_inversion/textual_inversion.py | 6 +- webui.py | 2 + 9 files changed, 158 insertions(+), 50 deletions(-) create mode 100644 modules/hashes.py diff --git a/.gitignore b/.gitignore index 21fa26a7..0b1d17ca 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,4 @@ notification.mp3 /extensions /test/stdout.txt /test/stderr.txt +/cache.json diff --git a/modules/api/api.py b/modules/api/api.py index 5767ba90..9814bbc2 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -371,7 +371,7 @@ class Api: return upscalers def get_sd_models(self): - return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] + return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] diff --git a/modules/api/models.py b/modules/api/models.py index c78095ca..1eb1fcf1 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -224,7 +224,8 @@ class UpscalerItem(BaseModel): class SDModelItem(BaseModel): title: str = Field(title="Title") model_name: str = Field(title="Model Name") - hash: str = 
Field(title="Hash") + hash: Optional[str] = Field(title="Short hash") + sha256: Optional[str] = Field(title="sha256 hash") filename: str = Field(title="Filename") config: str = Field(title="Config file") diff --git a/modules/hashes.py b/modules/hashes.py new file mode 100644 index 00000000..ebfbd90c --- /dev/null +++ b/modules/hashes.py @@ -0,0 +1,72 @@ +import hashlib +import json +import os.path + +import filelock + + +cache_filename = "cache.json" +cache_data = None + + +def dump_cache(): + with filelock.FileLock(cache_filename+".lock"): + with open(cache_filename, "w", encoding="utf8") as file: + json.dump(cache_data, file, indent=4) + + +def cache(subsection): + global cache_data + + if cache_data is None: + with filelock.FileLock(cache_filename+".lock"): + if not os.path.isfile(cache_filename): + cache_data = {} + else: + with open(cache_filename, "r", encoding="utf8") as file: + cache_data = json.load(file) + + s = cache_data.get(subsection, {}) + cache_data[subsection] = s + + return s + + +def calculate_sha256(filename): + hash_sha256 = hashlib.sha256() + + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + + +def sha256(filename, title): + hashes = cache("hashes") + ondisk_mtime = os.path.getmtime(filename) + + if title in hashes: + cached_sha256 = hashes[title].get("sha256", None) + cached_mtime = hashes[title].get("mtime", 0) + + if ondisk_mtime <= cached_mtime and cached_sha256 is not None: + return cached_sha256 + + print(f"Calculating sha256 for {filename}: ", end='') + sha256_value = calculate_sha256(filename) + print(f"{sha256_value}") + + hashes[title] = { + "mtime": ondisk_mtime, + "sha256": sha256_value, + } + + dump_cache() + + return sha256_value + + + + + diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 83cbb4f0..9b5f2e79 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -509,7 +509,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if shared.opts.save_training_settings_to_txt: saved_params = dict( - model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), + model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]} ) logging.save_settings_to_file(log_directory, {**saved_params, **locals()}) @@ -737,7 +737,7 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None try: - hypernetwork.sd_checkpoint = checkpoint.hash + hypernetwork.sd_checkpoint = checkpoint.shorthash hypernetwork.sd_checkpoint_name = checkpoint.model_name hypernetwork.name = hypernetwork_name hypernetwork.save(filename) diff --git a/modules/sd_models.py b/modules/sd_models.py index c466f273..7babb9ae 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,17 +14,56 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors +from modules import shared, modelloader, devices, script_callbacks, 
sd_vae, sd_disable_initialization, errors, hashes from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) -CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) checkpoints_list = {} +checkpoint_alisases = {} checkpoints_loaded = collections.OrderedDict() + +class CheckpointInfo: + def __init__(self, filename): + self.filename = filename + abspath = os.path.abspath(filename) + + if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): + name = abspath.replace(shared.cmd_opts.ckpt_dir, '') + elif abspath.startswith(model_path): + name = abspath.replace(model_path, '') + else: + name = os.path.basename(filename) + + if name.startswith("\\") or name.startswith("/"): + name = name[1:] + + self.title = name + self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] + self.hash = model_hash(filename) + self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + self.shorthash = None + self.sha256 = None + + def register(self): + checkpoints_list[self.title] = self + for id in self.ids: + checkpoint_alisases[id] = self + + def calculate_shorthash(self): + self.sha256 = hashes.sha256(self.filename, self.title) + self.shorthash = self.sha256[0:10] + + if self.shorthash not in self.ids: + self.ids += [self.shorthash, self.sha256] + self.register() + + return self.shorthash + + try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. @@ -43,10 +82,14 @@ def setup_model(): enable_midas_autodownload() -def checkpoint_tiles(): - convert = lambda name: int(name) if name.isdigit() else name.lower() - alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)] - return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key) +def checkpoint_tiles(): + def convert(name): + return int(name) if name.isdigit() else name.lower() + + def alphanumeric_key(key): + return [convert(c) for c in re.split('([0-9]+)', key)] + + return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) def find_checkpoint_config(info): @@ -62,48 +105,38 @@ def find_checkpoint_config(info): def list_models(): checkpoints_list.clear() + checkpoint_alisases.clear() model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"]) - def modeltitle(path, shorthash): - abspath = os.path.abspath(path) - - if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): - name = abspath.replace(shared.cmd_opts.ckpt_dir, '') - elif abspath.startswith(model_path): - name = abspath.replace(model_path, '') - else: - name = os.path.basename(path) - - if name.startswith("\\") or name.startswith("/"): - name = name[1:] - - shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] - - return f'{name} [{shorthash}]', shortname - cmd_ckpt = shared.cmd_opts.ckpt if os.path.exists(cmd_ckpt): - h = model_hash(cmd_ckpt) - title, short_model_name = modeltitle(cmd_ckpt, h) - checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) - shared.opts.data['sd_model_checkpoint'] = title + checkpoint_info = CheckpointInfo(cmd_ckpt) + checkpoint_info.register() + + 
shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) + for filename in model_list: - h = model_hash(filename) - title, short_model_name = modeltitle(filename, h) + checkpoint_info = CheckpointInfo(filename) + checkpoint_info.register() + - checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name) +def get_closet_checkpoint_match(search_string): + checkpoint_info = checkpoint_alisases.get(search_string, None) + if checkpoint_info is not None: + return + found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title)) + if found: + return found[0] -def get_closet_checkpoint_match(searchString): - applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title)) - if len(applicable) > 0: - return applicable[0] return None def model_hash(filename): + """old hash that only looks at a small part of the file and is prone to collisions""" + try: with open(filename, "rb") as file: import hashlib @@ -119,7 +152,7 @@ def model_hash(filename): def select_checkpoint(): model_checkpoint = shared.opts.sd_model_checkpoint - checkpoint_info = checkpoints_list.get(model_checkpoint, None) + checkpoint_info = checkpoint_alisases.get(model_checkpoint, None) if checkpoint_info is not None: return checkpoint_info @@ -189,9 +222,8 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None return sd -def load_model_weights(model, checkpoint_info, vae_file="auto"): - checkpoint_file = checkpoint_info.filename - sd_model_hash = checkpoint_info.hash +def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): + sd_model_hash = checkpoint_info.calculate_shorthash() cache_enabled = shared.opts.sd_checkpoint_cache > 0 @@ -201,9 +233,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.load_state_dict(checkpoints_loaded[checkpoint_info]) else: # load from file - print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") + print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") - sd = read_state_dict(checkpoint_file) + sd = read_state_dict(checkpoint_info.filename) model.load_state_dict(sd, strict=False) del sd @@ -235,14 +267,14 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoints_loaded.popitem(last=False) # LRU model.sd_model_hash = sd_model_hash - model.sd_model_checkpoint = checkpoint_file + model.sd_model_checkpoint = checkpoint_info.filename model.sd_checkpoint_info = checkpoint_info model.logvar = model.logvar.to(devices.device) # fix for training sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() - vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) + vae_file = sd_vae.resolve_vae(checkpoint_info.filename, vae_file=vae_file) sd_vae.load_vae(model, vae_file) diff --git a/modules/shared.py b/modules/shared.py index b90ded52..d74c069d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -428,7 +428,7 @@ options_templates.update(options_section(('ui', "User interface"), { "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), - "add_model_name_to_info": 
OptionInfo(False, "Add model name to generation information"), + "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), "disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 6939efcc..63935878 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -407,7 +407,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize) if shared.opts.save_training_settings_to_txt: - save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) + save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) latent_sampling_method = ds.latent_sampling_method @@ -584,7 +584,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ checkpoint = sd_models.select_checkpoint() footer_left = checkpoint.model_name - footer_mid = '[{}]'.format(checkpoint.hash) + footer_mid = '[{}]'.format(checkpoint.shorthash) footer_right = '{}v {}s'.format(vectorSize, steps_done) captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) @@ -626,7 +626,7 @@ def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, r old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None try: - embedding.sd_checkpoint = checkpoint.hash + embedding.sd_checkpoint = checkpoint.shorthash embedding.sd_checkpoint_name = checkpoint.model_name if remove_cached_checksum: embedding.cached_checksum = None diff --git a/webui.py b/webui.py index 47d372c7..1fff80da 100644 --- a/webui.py +++ b/webui.py @@ -78,6 +78,8 @@ def initialize(): print("Stable diffusion model failed to load, exiting", file=sys.stderr) exit(1) + shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) -- cgit v1.2.3 From f9ac3352cb66ce2bc0aa4325130fc7267fb35e4f Mon Sep 17 00:00:00 2001 From: AUTOMATIC 
<16777216c@gmail.com> Date: Sat, 14 Jan 2023 10:25:21 +0300 Subject: change hypernets to use sha256 hashes --- modules/hypernetworks/hypernetwork.py | 40 ++++++++++++++++++++--------------- modules/processing.py | 2 +- modules/sd_models.py | 2 +- modules/shared.py | 1 + 4 files changed, 26 insertions(+), 19 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 9b5f2e79..3aebefa8 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -12,7 +12,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared, sd_samplers +from modules import devices, processing, sd_models, shared, sd_samplers, hashes from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -225,7 +225,7 @@ class Hypernetwork: torch.save(state_dict, filename) if shared.opts.save_optimizer_state and self.optimizer_state_dict: - optimizer_saved_dict['hash'] = sd_models.model_hash(filename) + optimizer_saved_dict['hash'] = self.shorthash() optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict torch.save(optimizer_saved_dict, filename + '.optim') @@ -237,32 +237,33 @@ class Hypernetwork: state_dict = torch.load(filename, map_location='cpu') self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) - print(self.layer_structure) - optional_info = state_dict.get('optional_info', None) - if optional_info is not None: - print(f"INFO:\n {optional_info}\n") - self.optional_info = optional_info + self.optional_info = state_dict.get('optional_info', None) self.activation_func = state_dict.get('activation_func', None) - print(f"Activation function is {self.activation_func}") self.weight_init = state_dict.get('weight_initialization', 'Normal') - print(f"Weight initialization is {self.weight_init}") self.add_layer_norm = state_dict.get('is_layer_norm', False) - print(f"Layer norm is set to {self.add_layer_norm}") self.dropout_structure = state_dict.get('dropout_structure', None) self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False) - print(f"Dropout usage is set to {self.use_dropout}" ) self.activate_output = state_dict.get('activate_output', True) - print(f"Activate last layer is set to {self.activate_output}") self.last_layer_dropout = state_dict.get('last_layer_dropout', False) # Dropout structure should have same length as layer structure, Every digits should be in [0,1), and last digit must be 0. 
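
For context: the short hash this commit introduces for hypernetworks is the same scheme checkpoints use, a full SHA-256 of the file, cached under a namespaced title such as f'hypernet/{name}', with the first 10 hex digits shown in infotexts. A minimal sketch of the idea, leaving out the mtime-checked cache that modules/hashes.py wraps around it:

    import hashlib

    def calculate_sha256(filename: str) -> str:
        # Stream the file in 1 MiB chunks so multi-gigabyte models
        # never have to fit in memory at once.
        digest = hashlib.sha256()
        with open(filename, "rb") as file:
            for chunk in iter(lambda: file.read(1024 * 1024), b""):
                digest.update(chunk)
        return digest.hexdigest()

    def shorthash(filename: str) -> str:
        # Ten hex digits are plenty to tell local model files apart.
        return calculate_sha256(filename)[0:10]
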
if self.dropout_structure is None: - print("Using previous dropout structure") self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout) - print(f"Dropout structure is set to {self.dropout_structure}") - optimizer_saved_dict = torch.load(self.filename + '.optim', map_location = 'cpu') if os.path.exists(self.filename + '.optim') else {} + if shared.opts.print_hypernet_extra: + if self.optional_info is not None: + print(f" INFO:\n {self.optional_info}\n") - if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None): + print(f" Layer structure: {self.layer_structure}") + print(f" Activation function: {self.activation_func}") + print(f" Weight initialization: {self.weight_init}") + print(f" Layer norm: {self.add_layer_norm}") + print(f" Dropout usage: {self.use_dropout}" ) + print(f" Activate last layer: {self.activate_output}") + print(f" Dropout structure: {self.dropout_structure}") + + optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {} + + if self.shorthash() == optimizer_saved_dict.get('hash', None): self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None) else: self.optimizer_state_dict = None @@ -289,6 +290,11 @@ class Hypernetwork: self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) self.eval() + def shorthash(self): + sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}') + + return sha256[0:10] + def list_hypernetworks(path): res = {} @@ -296,7 +302,7 @@ def list_hypernetworks(path): name = os.path.splitext(os.path.basename(filename))[0] # Prevent a hypothetical "None.pt" from being listed. if name != "None": - res[name + f"({sd_models.model_hash(filename)})"] = filename + res[name] = filename return res diff --git a/modules/processing.py b/modules/processing.py index ae04cab7..849f6b19 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -437,7 +437,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), - "Hypernet hash": (None if shared.loaded_hypernetwork is None else sd_models.model_hash(shared.loaded_hypernetwork.filename)), + "Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()), "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), diff --git a/modules/sd_models.py b/modules/sd_models.py index 7babb9ae..8f00191c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -125,7 +125,7 @@ def list_models(): def get_closet_checkpoint_match(search_string): checkpoint_info = checkpoint_alisases.get(search_string, None) if checkpoint_info is not None: - return + return checkpoint_info found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title)) if found: diff --git 
a/modules/shared.py b/modules/shared.py index d74c069d..a6c61db3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -361,6 +361,7 @@ options_templates.update(options_section(('system', "System"), { "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), + "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), })) options_templates.update(options_section(('training', "Training"), { -- cgit v1.2.3 From febd2b722e80959b89a0e5966a159b4eb430c5a5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 13:37:55 +0300 Subject: update key to use with checkpoints' sha256 in cache --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 8f00191c..1fe6d11b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -54,7 +54,7 @@ class CheckpointInfo: checkpoint_alisases[id] = self def calculate_shorthash(self): - self.sha256 = hashes.sha256(self.filename, self.title) + self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title) self.shorthash = self.sha256[0:10] if self.shorthash not in self.ids: -- cgit v1.2.3 From 6eb72fd13f34d94d5459290dd1a0bf0e9ddeda82 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 13:38:10 +0300 Subject: bump gradio to 3.16.1 --- modules/ui.py | 13 ++++++----- requirements.txt | 2 +- requirements_versions.txt | 2 +- style.css | 57 ++++++++++++++++++++++++++++++++--------------- webui.py | 3 +-- 5 files changed, 49 insertions(+), 28 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index e86a624b..202e84e5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -605,7 +605,7 @@ def create_ui(): setup_progressbar(progressbar, txt2img_preview, 'txt2img') with gr.Row().style(equal_height=False): - with gr.Column(variant='panel', elem_id="txt2img_settings"): + with gr.Column(variant='compact', elem_id="txt2img_settings"): for category in ordered_ui_categories(): if category == "sampler": steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") @@ -794,7 +794,7 @@ def create_ui(): setup_progressbar(progressbar, img2img_preview, 'img2img') with FormRow().style(equal_height=False): - with gr.Column(variant='panel', elem_id="img2img_settings"): + with gr.Column(variant='compact', elem_id="img2img_settings"): with gr.Tabs(elem_id="mode_img2img"): with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480) @@ -1026,7 +1026,7 @@ def create_ui(): with gr.Blocks(analytics_enabled=False) as extras_interface: with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): + with gr.Column(variant='compact'): with gr.Tabs(elem_id="mode_extras"): with gr.TabItem('Single Image', elem_id="extras_single_tab"): extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") @@ -1127,8 +1127,8 @@ def create_ui(): with gr.Blocks(analytics_enabled=False) as modelmerger_interface: with 
gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - gr.HTML(value="A merger of the two checkpoints will be generated in your checkpoint directory.") + with gr.Column(variant='compact'): + gr.HTML(value="A merger of the two checkpoints will be generated in your checkpoint directory.
") with FormRow(): primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") @@ -1150,7 +1150,8 @@ def create_ui(): config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method") - modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') + with gr.Row(): + modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') with gr.Column(variant='panel'): submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) diff --git a/requirements.txt b/requirements.txt index e1dbf8e5..6cdea781 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ fairscale==0.4.4 fonts font-roboto gfpgan -gradio==3.15.0 +gradio==3.16.1 invisible-watermark numpy omegaconf diff --git a/requirements_versions.txt b/requirements_versions.txt index d2899292..cc06d2b4 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -3,7 +3,7 @@ transformers==4.19.2 accelerate==0.12.0 basicsr==1.4.2 gfpgan==1.3.8 -gradio==3.15.0 +gradio==3.16.1 numpy==1.23.3 Pillow==9.4.0 realesrgan==0.3.0 diff --git a/style.css b/style.css index ffd6307f..14b15191 100644 --- a/style.css +++ b/style.css @@ -20,7 +20,7 @@ padding-right: 0.25em; margin: 0.1em 0; opacity: 0%; - cursor: default; + cursor: default; } .output-html p {margin: 0 0.5em;} @@ -221,7 +221,10 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s .dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{ background-color: rgb(31, 41, 55); - box-shadow: 6px 0 6px 0px rgb(31, 41, 55), -6px 0 6px 0px rgb(31, 41, 55); + box-shadow: none; + border: 1px solid rgba(128, 128, 128, 0.1); + border-radius: 6px; + padding: 0.1em 0.5em; } #txt2img_column_batch, #img2img_column_batch{ @@ -371,7 +374,7 @@ input[type="range"]{ grid-area: tile; } -.modalClose, +.modalClose, .modalZoom, .modalTileImage { color: white; @@ -509,29 +512,20 @@ input[type="range"]{ } #quicksettings > div{ - border: none; - background: none; - flex: unset; - gap: 1em; -} - -#quicksettings > div > div{ - max-width: 32em; + max-width: 24em; min-width: 24em; padding: 0; + border: none; + box-shadow: none; + background: none; } -#quicksettings > div > div > div > div > label > span { +#quicksettings > div > div > div > label > span { position: relative; margin-right: 9em; margin-bottom: -1em; } -#quicksettings > div > div > label > span { - position: relative; - margin-bottom: -1em; -} - canvas[key="mask"] { z-index: 12 !important; filter: invert(); @@ -666,7 +660,10 @@ footer { font-weight: bold; } -#txt2img_checkboxes > div > div{ +#txt2img_checkboxes, #img2img_checkboxes{ + margin-bottom: 0.5em; +} +#txt2img_checkboxes > div > div, #img2img_checkboxes > div > div{ flex: 0; white-space: nowrap; min-width: auto; @@ -676,6 +673,30 @@ footer { opacity: 0.5; } +.gr-compact { + border: none; + padding-top: 1em; +} + +.dark .gr-compact{ + background-color: rgb(31 41 55 / var(--tw-bg-opacity)); + margin-left: 0.8em; +} + +.gr-compact > *{ + margin-top: 0.5em !important; +} + +.gr-compact .gr-block, .gr-compact .gr-form{ + border: none; + box-shadow: none; +} + +.gr-compact .gr-box{ + border-radius: .5rem !important; + border-width: 1px !important; +} + /* The following handles localization for right-to-left (RTL) languages like Arabic. 
The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running diff --git a/webui.py b/webui.py index 1fff80da..84159515 100644 --- a/webui.py +++ b/webui.py @@ -157,7 +157,7 @@ def webui(): shared.demo = modules.ui.create_ui() - app, local_url, share_url = shared.demo.queue(default_enabled=False).launch( + app, local_url, share_url = shared.demo.launch( share=cmd_opts.share, server_name=server_name, server_port=cmd_opts.port, @@ -185,7 +185,6 @@ def webui(): create_api(app) modules.script_callbacks.app_started_callback(shared.demo, app) - modules.script_callbacks.app_started_callback(shared.demo, app) wait_on_server(shared.demo) print('Restarting UI...') -- cgit v1.2.3 From 5f8685237ed6427c9a8e502124074c740ea7696a Mon Sep 17 00:00:00 2001 From: bbc_mc Date: Sat, 14 Jan 2023 20:00:00 +0900 Subject: Exclude clip index from merge --- modules/extras.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/modules/extras.py b/modules/extras.py index a03d558e..22668fcd 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -326,8 +326,14 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam print("Merging...") + chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] + for key in tqdm.tqdm(theta_0.keys()): if 'model' in key and key in theta_1: + + if key in chckpoint_dict_skip_on_merge: + continue + a = theta_0[key] b = theta_1[key] @@ -352,6 +358,10 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam # I believe this part should be discarded, but I'll leave it for now until I am sure for key in theta_1.keys(): if 'model' in key and key not in theta_0: + + if key in chckpoint_dict_skip_on_merge: + continue + theta_0[key] = theta_1[key] if save_as_half: theta_0[key] = theta_0[key].half() -- cgit v1.2.3 From 54fa77facc1849fbbfe61c1ca6d99b117d609d67 Mon Sep 17 00:00:00 2001 From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Date: Sat, 14 Jan 2023 12:10:45 +0100 Subject: Fix detection script on macos This fixes the script on macos --- detection.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/detection.py b/detection.py index eb4db0df..442c4be5 100644 --- a/detection.py +++ b/detection.py @@ -31,6 +31,8 @@ def check_gpu(): return "NVIDIA" else: return "Unknown" + else: + return "Unknown" else: # If the `lspci` command is available, use it to get the GPU vendor and model information output = os.popen("lspci | grep -i vga").read() -- cgit v1.2.3 From 865228a83736bea9ede33e98041f2a7d0ca5daaa Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 14:56:39 +0300 Subject: change style dropdowns to multiselect --- javascript/localization.js | 6 ++---- modules/img2img.py | 4 ++-- modules/styles.py | 12 ++++++++--- modules/txt2img.py | 4 ++-- modules/ui.py | 53 ++++++++++++++++++++++++++-------------------- style.css | 15 ++++++++++--- 6 files changed, 57 insertions(+), 37 deletions(-) diff --git a/javascript/localization.js b/javascript/localization.js index f92d2d24..bf9e1506 100644 --- a/javascript/localization.js +++ b/javascript/localization.js @@ -10,10 +10,8 @@ ignore_ids_for_localization={ modelmerger_tertiary_model_name: 'OPTION', train_embedding: 'OPTION', train_hypernetwork: 'OPTION', - txt2img_style_index: 'OPTION', - txt2img_style2_index: 'OPTION', - img2img_style_index: 'OPTION', - img2img_style2_index: 'OPTION', + 
txt2img_styles: 'OPTION', + img2img_styles: 'OPTION', setting_random_artist_categories: 'SPAN', setting_face_restoration_model: 'SPAN', setting_realesrgan_enabled_models: 'SPAN', diff --git a/modules/img2img.py b/modules/img2img.py index f62783c6..79382cc1 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): +def img2img(mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): is_batch = mode == 5 if mode == 0: # img2img @@ -101,7 +101,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids, prompt=prompt, negative_prompt=negative_prompt, - styles=[prompt_style, prompt_style2], + styles=prompt_styles, seed=seed, subseed=subseed, subseed_strength=subseed_strength, diff --git a/modules/styles.py b/modules/styles.py index ce6e71ca..990d5623 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -40,12 +40,18 @@ def apply_styles_to_prompt(prompt, styles): class StyleDatabase: def __init__(self, path: str): self.no_style = PromptStyle("None", "", "") - self.styles = {"None": self.no_style} + self.styles = {} + self.path = path - if not os.path.exists(path): + self.reload() + + def reload(self): + self.styles.clear() + + if not os.path.exists(self.path): return - with open(path, "r", encoding="utf-8-sig", newline='') as file: + with open(self.path, "r", encoding="utf-8-sig", newline='') as file: reader = csv.DictReader(file) for row in reader: # Support loading old CSV format with "name, text"-columns diff --git a/modules/txt2img.py b/modules/txt2img.py index 38b5f591..5a71793b 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -8,13 +8,13 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int,
seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids, prompt=prompt, - styles=[prompt_style, prompt_style2], + styles=prompt_styles, negative_prompt=negative_prompt, seed=seed, subseed=subseed, diff --git a/modules/ui.py b/modules/ui.py index 202e84e5..db198a47 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -180,7 +180,7 @@ def add_style(name: str, prompt: str, negative_prompt: str): # reserialize all styles every time we save them shared.prompt_styles.save_styles(shared.styles_filename) - return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] + return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)] def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): @@ -197,11 +197,11 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" -def apply_styles(prompt, prompt_neg, style1_name, style2_name): - prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) - prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) +def apply_styles(prompt, prompt_neg, styles): + prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles) + prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles) - return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] + return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])] def interrogate(image): @@ -374,13 +374,10 @@ def create_toprow(is_img2img): ) with gr.Row(): - with gr.Column(scale=1, elem_id="style_pos_col"): - prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True) + create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles") - with gr.Column(scale=1, elem_id="style_neg_col"): - prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - - return prompt, prompt_style, 
negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button + return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button def setup_progressbar(*args, **kwargs): @@ -590,7 +587,7 @@ def create_ui(): modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) + txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) @@ -684,8 +681,7 @@ def create_ui(): inputs=[ txt2img_prompt, txt2img_negative_prompt, - txt2img_prompt_style, - txt2img_prompt_style2, + txt2img_prompt_styles, steps, sampler_index, restore_faces, @@ -780,7 +776,7 @@ def create_ui(): modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) + img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) with gr.Row(elem_id='img2img_progress_row'): img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) @@ -921,8 +917,7 @@ def create_ui(): dummy_component, img2img_prompt, img2img_negative_prompt, - img2img_prompt_style, - img2img_prompt_style2, + img2img_prompt_styles, init_img, sketch, init_img_with_mask, @@ -977,7 +972,7 @@ def create_ui(): ) prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] - style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] + style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles] style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): @@ -987,15 +982,15 @@ def create_ui(): # Have to pass empty dummy component here, because the JavaScript and Python function have to accept # the same number of parameters, but we only know the style-name after the JavaScript prompt inputs=[dummy_component, prompt, negative_prompt], - outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], + outputs=[txt2img_prompt_styles, img2img_prompt_styles], ) - for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): + for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, 
img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): button.click( fn=apply_styles, _js=js_func, - inputs=[prompt, negative_prompt, style1, style2], - outputs=[prompt, negative_prompt, style1, style2], + inputs=[prompt, negative_prompt, styles], + outputs=[prompt, negative_prompt, styles], ) token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) @@ -1530,6 +1525,7 @@ def create_ui(): previous_section = None current_tab = None + current_row = None with gr.Tabs(elem_id="settings"): for i, (k, item) in enumerate(opts.data_labels.items()): section_must_be_skipped = item.section[0] is None @@ -1538,10 +1534,14 @@ def create_ui(): elem_id, text = item.section if current_tab is not None: + current_row.__exit__() current_tab.__exit__() + gr.Group() current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) current_tab.__enter__() + current_row = gr.Column(variant='compact') + current_row.__enter__() previous_section = item.section @@ -1556,6 +1556,7 @@ def create_ui(): components.append(component) if current_tab is not None: + current_row.__exit__() current_tab.__exit__() with gr.TabItem("Actions"): @@ -1774,7 +1775,13 @@ def create_ui(): apply_field(x, 'value') if type(x) == gr.Dropdown: - apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) + def check_dropdown(val): + if x.multiselect: + return all([value in x.choices for value in val]) + else: + return val in x.choices + + apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None)) visit(txt2img_interface, loadsave, "txt2img") visit(img2img_interface, loadsave, "img2img") diff --git a/style.css b/style.css index 14b15191..c1ebd501 100644 --- a/style.css +++ b/style.css @@ -114,6 +114,7 @@ min-width: unset !important; flex-grow: 0 !important; padding: 0.4em 0; + gap: 0; } #roll_col > button { @@ -141,10 +142,14 @@ min-width: 8em !important; } -#txt2img_style_index, #txt2img_style2_index, #img2img_style_index, #img2img_style2_index{ +#txt2img_styles, #img2img_styles{ margin-top: 1em; } +#txt2img_styles ul, #img2img_styles ul{ + max-height: 35em; +} + .gr-form{ background: transparent; } @@ -154,10 +159,14 @@ margin-bottom: 0; } -#toprow div{ +#toprow div.gr-box, #toprow div.gr-form{ border: none; gap: 0; background: transparent; + box-shadow: none; +} +#toprow div{ + gap: 0; } #resize_mode{ @@ -615,7 +624,7 @@ canvas[key="mask"] { max-width: 2.5em; min-width: 2.5em !important; height: 2.4em; - margin: 0.55em 0; + margin: 1.6em 0 0 0; } #quicksettings .gr-button-tool{ -- cgit v1.2.3 From 08c6f009a5ee92dd3218a942c08e8337c26352be Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 15:55:40 +0300 Subject: load hashes from cache for checkpoints that have them add checkpoint hash to footer --- javascript/ui.js | 25 ++++++++++++++++--------- modules/hashes.py | 26 +++++++++++++++++++------- modules/sd_models.py | 9 ++++++--- modules/shared.py | 1 + modules/ui.py | 2 ++ script.js | 4 ++++ 6 files changed, 48 insertions(+), 19 deletions(-) diff --git a/javascript/ui.js b/javascript/ui.js index a41dd26f..1e04a8f4 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -143,14 +143,6 @@ function confirm_clear_prompt(prompt, negative_prompt) { opts = {} -function apply_settings(jsdata){ - console.log(jsdata) - - opts = JSON.parse(jsdata) - - return jsdata -} - onUiUpdate(function(){ if(Object.keys(opts).length != 0) return; @@ -160,7 +152,7 @@ onUiUpdate(function(){ textarea = 
json_elem.querySelector('textarea') jsdata = textarea.value opts = JSON.parse(jsdata) - + executeCallbacks(optionsChangedCallbacks); Object.defineProperty(textarea, 'value', { set: function(newValue) { @@ -171,6 +163,8 @@ onUiUpdate(function(){ if (oldValue != newValue) { opts = JSON.parse(textarea.value) } + + executeCallbacks(optionsChangedCallbacks); }, get: function() { var valueProp = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value'); @@ -201,6 +195,19 @@ onUiUpdate(function(){ } }) + +onOptionsChanged(function(){ + elem = gradioApp().getElementById('sd_checkpoint_hash') + sd_checkpoint_hash = opts.sd_checkpoint_hash || "" + shorthash = sd_checkpoint_hash.substr(0,10) + + if(elem && elem.textContent != shorthash){ + elem.textContent = shorthash + elem.title = sd_checkpoint_hash + elem.href = "https://google.com/search?q=" + sd_checkpoint_hash + } +}) + let txt2img_textarea, img2img_textarea = undefined; let wait_time = 800 let token_timeout; diff --git a/modules/hashes.py b/modules/hashes.py index ebfbd90c..14231771 100644 --- a/modules/hashes.py +++ b/modules/hashes.py @@ -42,23 +42,35 @@ def calculate_sha256(filename): return hash_sha256.hexdigest() -def sha256(filename, title): +def sha256_from_cache(filename, title): hashes = cache("hashes") ondisk_mtime = os.path.getmtime(filename) - if title in hashes: - cached_sha256 = hashes[title].get("sha256", None) - cached_mtime = hashes[title].get("mtime", 0) + if title not in hashes: + return None + + cached_sha256 = hashes[title].get("sha256", None) + cached_mtime = hashes[title].get("mtime", 0) + + if ondisk_mtime > cached_mtime or cached_sha256 is None: + return None + + return cached_sha256 + + +def sha256(filename, title): + hashes = cache("hashes") - if ondisk_mtime <= cached_mtime and cached_sha256 is not None: - return cached_sha256 + sha256_value = sha256_from_cache(filename, title) + if sha256_value is not None: + return sha256_value print(f"Calculating sha256 for {filename}: ", end='') sha256_value = calculate_sha256(filename) print(f"{sha256_value}") hashes[title] = { - "mtime": ondisk_mtime, + "mtime": os.path.getmtime(filename), "sha256": sha256_value, } diff --git a/modules/sd_models.py b/modules/sd_models.py index 1fe6d11b..e5a0bc63 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -44,9 +44,11 @@ class CheckpointInfo: self.title = name self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] self.hash = model_hash(filename) - self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] - self.shorthash = None - self.sha256 = None + + self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title) + self.shorthash = self.sha256[0:10] if self.sha256 else None + + self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else []) def register(self): checkpoints_list[self.title] = self @@ -269,6 +271,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_info.filename model.sd_checkpoint_info = checkpoint_info + shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256 model.logvar = model.logvar.to(devices.device) # fix for training diff --git a/modules/shared.py b/modules/shared.py index a6c61db3..c9988d4d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -458,6 +458,7 @@ options_templates.update(options_section(('sampler-params', "Sampler 
parameters" options_templates.update(options_section((None, "Hidden options"), { "disabled_extensions": OptionInfo([], "Disable those extensions"), + "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), })) options_templates.update() diff --git a/modules/ui.py b/modules/ui.py index e86a624b..2625ae32 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1841,4 +1841,6 @@ xformers: {xformers_version} gradio: {gr.__version__}  •  commit: {short_commit} + •  +checkpoint: N/A """ diff --git a/script.js b/script.js index 21960d91..3345e32b 100644 --- a/script.js +++ b/script.js @@ -14,6 +14,7 @@ function get_uiCurrentTabContent() { uiUpdateCallbacks = [] uiTabChangeCallbacks = [] +optionsChangedCallbacks = [] let uiCurrentTab = null function onUiUpdate(callback){ @@ -22,6 +23,9 @@ function onUiUpdate(callback){ function onUiTabChange(callback){ uiTabChangeCallbacks.push(callback) } +function onOptionsChanged(callback){ + optionsChangedCallbacks.push(callback) +} function runCallback(x, m){ try { -- cgit v1.2.3 From f94a215abed85b34ae978853078812801d3e7738 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 16:29:23 +0300 Subject: add an option to choose what you want to see in live preview (Live preview subject) and moves live preview settings to its own tab --- modules/sd_samplers.py | 15 ++++++++++----- modules/shared.py | 13 +++++++++---- modules/ui_progress.py | 2 +- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 01221b89..7616fded 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -138,7 +138,7 @@ def samples_to_image_grid(samples, approximation=None): def store_latent(decoded): state.current_latent = decoded - if opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: + if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: if not shared.parallel_processing_allowed: shared.state.current_image = sample_to_image(decoded) @@ -243,7 +243,7 @@ class VanillaStableDiffusionSampler: self.nmask = p.nmask if hasattr(p, 'nmask') else None def adjust_steps_if_invalid(self, p, num_steps): - if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): + if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'): valid_step = 999 / (1000 // num_steps) if valid_step == floor(valid_step): return int(valid_step) + 1 @@ -266,8 +266,7 @@ class VanillaStableDiffusionSampler: if image_conditioning is not None: conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - - + samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) return samples @@ -352,6 +351,11 @@ class CFGDenoiser(torch.nn.Module): x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) + if opts.live_preview_content == "Prompt": + store_latent(x_out[0:uncond.shape[0]]) + elif opts.live_preview_content == "Negative prompt": + store_latent(x_out[-uncond.shape[0]:]) + denoised = 
self.combine_denoised(x_out, conds_list, uncond, cond_scale) if self.mask is not None: @@ -423,7 +427,8 @@ class KDiffusionSampler: def callback_state(self, d): step = d['i'] latent = d["denoised"] - store_latent(latent) + if opts.live_preview_content == "Combined": + store_latent(latent) self.last_latent = latent if self.stop_at is not None and step > self.stop_at: diff --git a/modules/shared.py b/modules/shared.py index c9988d4d..e0ec3136 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -176,7 +176,7 @@ class State: self.interrupted = True def nextjob(self): - if opts.show_progress_every_n_steps == -1: + if opts.live_previews_enable and opts.show_progress_every_n_steps == -1: self.do_set_current_image() self.job_no += 1 @@ -224,7 +224,7 @@ class State: if not parallel_processing_allowed: return - if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.show_progress_every_n_steps > 0: + if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable: self.do_set_current_image() def do_set_current_image(self): @@ -423,8 +423,6 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), - "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set to 0 to disable. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), - "show_progress_type": OptionInfo("Full", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}), "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), @@ -444,6 +442,13 @@ options_templates.update(options_section(('ui', "User interface"), { 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) +options_templates.update(options_section(('ui', "Live previews"), { + "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), + "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. 
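
Taken together, the sampler and options changes in this commit gate previews on an explicit enable flag plus a step interval instead of overloading one slider. The gating logic, reduced to a sketch (opts stands in for modules.shared.opts):

    def should_store_preview(opts, sampling_step: int) -> bool:
        # No previews of any kind unless the feature is switched on.
        if not opts.live_previews_enable:
            return False
        # 0 disables per-step previews; -1 defers to the end of the batch,
        # which State.nextjob handles separately.
        if opts.show_progress_every_n_steps <= 0:
            return False
        # Otherwise refresh the preview every N sampling steps.
        return sampling_step % opts.show_progress_every_n_steps == 0
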
Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), + "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}), + "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), +})) + options_templates.update(options_section(('sampler-params', "Sampler parameters"), { "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in list_samplers()]}), "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), diff --git a/modules/ui_progress.py b/modules/ui_progress.py index 592fda55..7cd312e4 100644 --- a/modules/ui_progress.py +++ b/modules/ui_progress.py @@ -52,7 +52,7 @@ def check_progress_call(id_part): image = gr.update(visible=False) preview_visibility = gr.update(visible=False) - if opts.show_progress_every_n_steps != 0: + if opts.live_previews_enable: shared.state.set_current_image() image = shared.state.current_image -- cgit v1.2.3 From 69781031e7473e020b3af4461fdceb20130e56ab Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 16:45:39 +0300 Subject: simplify expression in prompts from file script --- scripts/prompts_from_file.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 1fe10a7c..f3e711d7 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -146,11 +146,7 @@ class Script(scripts.Script): else: args = {"prompt": line} - n_iter = args.get("n_iter", p.n_iter) - if n_iter != 1: - job_count += n_iter - else: - job_count += 1 + job_count += args.get("n_iter", p.n_iter) jobs.append(args) -- cgit v1.2.3 From 934cba0f4ca3e80a2079a657ebb6ca8c1ee2d10b Mon Sep 17 00:00:00 2001 From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Date: Sat, 14 Jan 2023 15:43:29 +0100 Subject: Delete detection.py --- detection.py | 47 ----------------------------------------------- 1 file changed, 47 deletions(-) delete mode 100644 detection.py diff --git a/detection.py b/detection.py deleted file mode 100644 index 442c4be5..00000000 --- a/detection.py +++ /dev/null @@ -1,47 +0,0 @@ -# This script detects which GPU is currently used in Windows and Linux -import os -import sys - -def check_gpu(): - # First, check if the `lspci` command is available - if not os.system("which lspci > /dev/null") == 0: - # If the `lspci` command is not available, try the `dxdiag` command on Windows - if os.name == "nt": - # On Windows, run the `dxdiag` command and check the output for the "Card name" field - # Create the dxdiag.txt file - os.system("dxdiag /t dxdiag.txt") - - # Read the dxdiag.txt file - with open("dxdiag.txt", "r") as f: - output = f.read() - - if "Card name" in output: - card_name_start = output.index("Card name: ") + len("Card name: ") - card_name_end = output.index("\n", card_name_start) - card_name = output[card_name_start:card_name_end] - else: - card_name = "Unknown" - print(f"Card name: {card_name}") - os.remove("dxdiag.txt") - if "AMD" in card_name: - return "AMD" - elif "Intel" in card_name: - return "Intel" - elif "NVIDIA" in card_name: - return "NVIDIA" - else: - return "Unknown" - else: - return "Unknown" - else: - # If the `lspci` command is available, use it to get the GPU vendor 
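
For reference, the check detection.py performed before its removal here (and its unwinding from launch.py in the next commit) amounts to grepping the VGA controller line of lspci output for known vendor strings; a condensed sketch of the Linux branch:

    import subprocess

    def gpu_vendor() -> str:
        # List PCI devices and keep only VGA controller lines.
        try:
            output = subprocess.check_output(["lspci"], text=True)
        except (OSError, subprocess.CalledProcessError):
            return "Unknown"
        vga_lines = [line for line in output.splitlines() if "vga" in line.lower()]
        # First known vendor name found on a VGA line wins.
        for vendor in ("AMD", "Intel", "NVIDIA"):
            if any(vendor in line for line in vga_lines):
                return vendor
        return "Unknown"
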
and model information - output = os.popen("lspci | grep -i vga").read() - if "AMD" in output: - return "AMD" - elif "Intel" in output: - return "Intel" - elif "NVIDIA" in output: - return "NVIDIA" - else: - return "Unknown" - -- cgit v1.2.3 From 629612935372aa7faeb764b5660431e49b93de24 Mon Sep 17 00:00:00 2001 From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Date: Sat, 14 Jan 2023 15:45:07 +0100 Subject: Revert detection code --- launch.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/launch.py b/launch.py index 668548f1..bcbb792c 100644 --- a/launch.py +++ b/launch.py @@ -7,7 +7,6 @@ import shlex import platform import argparse import json -import detection dir_repos = "repositories" dir_extensions = "extensions" @@ -16,12 +15,6 @@ git = os.environ.get('GIT', "git") index_url = os.environ.get('INDEX_URL', "") stored_commit_hash = None -# Get the GPU vendor and the operating system -gpu = detection.check_gpu() -if os.name == "posix": - os_name = platform.uname().system -else: - os_name = os.name def commit_hash(): global stored_commit_hash @@ -180,11 +173,7 @@ def run_extensions_installers(settings_file): def prepare_environment(): - if gpu == "AMD" and os_name !="nt": - torch_command = os.environ.get('TORCH_COMMAND', "pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2") - else: - torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") - + torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") commandline_args = os.environ.get('COMMANDLINE_ARGS', "") @@ -306,8 +295,6 @@ def tests(test_dir): def start(): - print(f"Operating System: {os_name}") - print(f"GPU: {gpu}") print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}") import webui if '--nowebui' in sys.argv: -- cgit v1.2.3 From 6192a222bf7131771a2cd7655a64a5b24a1e6e2e Mon Sep 17 00:00:00 2001 From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Date: Sat, 14 Jan 2023 15:46:23 +0100 Subject: Export TORCH_COMMAND for AMD from the webui --- webui.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/webui.sh b/webui.sh index 23629ef9..35f52f2a 100755 --- a/webui.sh +++ b/webui.sh @@ -168,6 +168,7 @@ else gpu_info=$(lspci | grep VGA) if echo "$gpu_info" | grep -q "AMD" then + export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2" HSA_OVERRIDE_GFX_VERSION=10.3.0 exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" else exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" -- cgit v1.2.3 From c4ba34928ec7f977585494f0fa5925496c887698 Mon Sep 17 00:00:00 2001 From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Date: Sat, 14 Jan 2023 15:58:50 +0100 Subject: Quick format fix --- webui.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.sh b/webui.sh index 35f52f2a..fcba6b7d 100755 --- a/webui.sh +++ b/webui.sh @@ -168,7 +168,7 @@ else gpu_info=$(lspci | grep VGA) if echo "$gpu_info" | grep -q "AMD" then - export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2" + export TORCH_COMMAND="pip install torch torchvision --extra-index-url 
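
Exporting TORCH_COMMAND from webui.sh is enough because prepare_environment in launch.py (restored above) reads the environment before falling back to its CUDA default, so the shell script can steer AMD users to the ROCm wheel without launch.py knowing anything about GPUs. The precedence, condensed:

    import os

    def torch_install_command() -> str:
        # Anything exported by the wrapper script (e.g. the ROCm build
        # webui.sh selects on AMD hardware) takes priority over the default.
        cuda_default = ("pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 "
                        "--extra-index-url https://download.pytorch.org/whl/cu113")
        return os.environ.get("TORCH_COMMAND", cuda_default)
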
https://download.pytorch.org/whl/rocm5.2" HSA_OVERRIDE_GFX_VERSION=10.3.0 exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" else exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" -- cgit v1.2.3 From fad850fc3d33e7cda2ce4b3a32ab7976c313db53 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sat, 14 Jan 2023 11:18:05 -0500 Subject: add server_start to shared.state --- modules/shared.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/shared.py b/modules/shared.py index e0ec3136..ef93637c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -168,6 +168,7 @@ class State: textinfo = None time_start = None need_restart = False + server_start = None def skip(self): self.skipped = True @@ -241,6 +242,7 @@ class State: state = State() +state.server_start = time.time() artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv')) -- cgit v1.2.3 From a5bbcd215304e0c83ab2b9fe7f172f88536d7629 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 19:56:09 +0300 Subject: fix bug with "Ignore selected VAE for..." option completely disabling VAE election rework VAE resolving code to be more simple --- modules/sd_models.py | 6 +- modules/sd_vae.py | 194 ++++++++++++++++++++------------------------------- modules/shared.py | 4 +- scripts/xy_grid.py | 27 ++++--- 4 files changed, 95 insertions(+), 136 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index e5a0bc63..6a681cef 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -224,7 +224,7 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None return sd -def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): +def load_model_weights(model, checkpoint_info: CheckpointInfo): sd_model_hash = checkpoint_info.calculate_shorthash() cache_enabled = shared.opts.sd_checkpoint_cache > 0 @@ -277,8 +277,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() - vae_file = sd_vae.resolve_vae(checkpoint_info.filename, vae_file=vae_file) - sd_vae.load_vae(model, vae_file) + vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) + sd_vae.load_vae(model, vae_file, vae_source) def enable_midas_autodownload(): diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 0a49daa1..6ea92711 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -9,23 +9,9 @@ import glob from copy import deepcopy -model_dir = "Stable-diffusion" -model_path = os.path.abspath(os.path.join(models_path, model_dir)) -vae_dir = "VAE" -vae_path = os.path.abspath(os.path.join(models_path, vae_dir)) - - +vae_path = os.path.abspath(os.path.join(models_path, "VAE")) vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} - - -default_vae_dict = {"auto": "auto", "None": None, None: None} -default_vae_list = ["auto", "None"] - - -default_vae_values = [default_vae_dict[x] for x in default_vae_list] -vae_dict = dict(default_vae_dict) -vae_list = list(default_vae_list) -first_load = True +vae_dict = {} base_vae = None @@ -64,100 +50,69 @@ def restore_base_vae(model): def get_filename(filepath): - return os.path.splitext(os.path.basename(filepath))[0] - - -def refresh_vae_list(vae_path=vae_path, model_path=model_path): - global vae_dict, vae_list - res = {} - candidates = [ - *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), - *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), - *glob.iglob(os.path.join(model_path, 
'**/*.vae.safetensors'), recursive=True),
-        *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
-        *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True),
-        *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True),
+    return os.path.basename(filepath)
+
+
+def refresh_vae_list():
+    vae_dict.clear()
+
+    paths = [
+        os.path.join(sd_models.model_path, '**/*.vae.ckpt'),
+        os.path.join(sd_models.model_path, '**/*.vae.pt'),
+        os.path.join(sd_models.model_path, '**/*.vae.safetensors'),
+        os.path.join(vae_path, '**/*.ckpt'),
+        os.path.join(vae_path, '**/*.pt'),
+        os.path.join(vae_path, '**/*.safetensors'),
     ]
-    if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
-        candidates.append(shared.cmd_opts.vae_path)
+
+    if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir):
+        paths += [
+            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'),
+            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'),
+            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),
+        ]
+
+    candidates = []
+    for path in paths:
+        candidates += glob.iglob(path, recursive=True)
+
     for filepath in candidates:
         name = get_filename(filepath)
-        res[name] = filepath
-    vae_list.clear()
-    vae_list.extend(default_vae_list)
-    vae_list.extend(list(res.keys()))
-    vae_dict.clear()
-    vae_dict.update(res)
-    vae_dict.update(default_vae_dict)
-    return vae_list
-
-
-def get_vae_from_settings(vae_file="auto"):
-    # else, we load from settings, if not set to be default
-    if vae_file == "auto" and shared.opts.sd_vae is not None:
-        # if saved VAE settings isn't recognized, fallback to auto
-        vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
-        # if VAE selected but not found, fallback to auto
-        if vae_file not in default_vae_values and not os.path.isfile(vae_file):
-            vae_file = "auto"
-            print(f"Selected VAE doesn't exist: {vae_file}")
-    return vae_file
-
-
-def resolve_vae(checkpoint_file=None, vae_file="auto"):
-    global first_load, vae_dict, vae_list
-
-    # if vae_file argument is provided, it takes priority, but not saved
-    if vae_file and vae_file not in default_vae_list:
-        if not os.path.isfile(vae_file):
-            print(f"VAE provided as function argument doesn't exist: {vae_file}")
-            vae_file = "auto"
-    # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
-    if first_load and shared.cmd_opts.vae_path is not None:
-        if os.path.isfile(shared.cmd_opts.vae_path):
-            vae_file = shared.cmd_opts.vae_path
-            shared.opts.data['sd_vae'] = get_filename(vae_file)
-        else:
-            print(f"VAE provided as command line argument doesn't exist: {vae_file}")
-    # fallback to selector in settings, if vae selector not set to act as default fallback
-    if not shared.opts.sd_vae_as_default:
-        vae_file = get_vae_from_settings(vae_file)
-    # vae-path cmd arg takes priority for auto
-    if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
-        if os.path.isfile(shared.cmd_opts.vae_path):
-            vae_file = shared.cmd_opts.vae_path
-            print(f"Using VAE provided as command line argument: {vae_file}")
-    # if still not found, try look for ".vae.pt" beside model
-    model_path = os.path.splitext(checkpoint_file)[0]
-    if vae_file == "auto":
-        vae_file_try = model_path + ".vae.pt"
-        if os.path.isfile(vae_file_try):
-            vae_file = vae_file_try
-            print(f"Using VAE found similar to selected model: {vae_file}")
-    # if still not found, try look for ".vae.ckpt" beside model
-    if vae_file == "auto":
-        vae_file_try = model_path + ".vae.ckpt"
-        if os.path.isfile(vae_file_try):
-            vae_file = vae_file_try
-            print(f"Using VAE found similar to selected model: {vae_file}")
-    # if still not found, try look for ".vae.safetensors" beside model
-    if vae_file == "auto":
-        vae_file_try = model_path + ".vae.safetensors"
-        if os.path.isfile(vae_file_try):
-            vae_file = vae_file_try
-            print(f"Using VAE found similar to selected model: {vae_file}")
-    # No more fallbacks for auto
-    if vae_file == "auto":
-        vae_file = None
-    # Last check, just because
-    if vae_file and not os.path.exists(vae_file):
-        vae_file = None
-
-    return vae_file
-
-
-def load_vae(model, vae_file=None):
-    global first_load, vae_dict, vae_list, loaded_vae_file
+        vae_dict[name] = filepath
+
+
+def find_vae_near_checkpoint(checkpoint_file):
+    checkpoint_path = os.path.splitext(checkpoint_file)[0]
+    for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]:
+        if os.path.isfile(vae_location):
+            return vae_location
+
+    return None
+
+
+def resolve_vae(checkpoint_file):
+    if shared.cmd_opts.vae_path is not None:
+        return shared.cmd_opts.vae_path, 'from commandline argument'
+
+    vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
+    if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "auto"):
+        return vae_near_checkpoint, 'found near the checkpoint'
+
+    if shared.opts.sd_vae == "None":
+        return None, None
+
+    vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
+    if vae_from_options is not None:
+        return vae_from_options, 'specified in settings'
+
+    if shared.opts.sd_vae != "Automatic":
+        print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
+
+    return None, None
+
+
+def load_vae(model, vae_file=None, vae_source="from unknown source"):
+    global vae_dict, loaded_vae_file
     # save_settings = False
 
     cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
@@ -165,12 +120,12 @@ def load_vae(model, vae_file=None):
     if vae_file:
         if cache_enabled and vae_file in checkpoints_loaded:
             # use vae checkpoint cache
-            print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
+            print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
             store_base_vae(model)
             _load_vae_dict(model, checkpoints_loaded[vae_file])
         else:
-            assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
-            print(f"Loading VAE weights from: {vae_file}")
+            assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
+            print(f"Loading VAE weights {vae_source}: {vae_file}")
             store_base_vae(model)
 
             vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
@@ -191,14 +146,12 @@ def load_vae(model, vae_file=None):
         vae_opt = get_filename(vae_file)
         if vae_opt not in vae_dict:
             vae_dict[vae_opt] = vae_file
-            vae_list.append(vae_opt)
+
     elif loaded_vae_file:
         restore_base_vae(model)
 
     loaded_vae_file = vae_file
 
-    first_load = False
-
 
 # don't call this from outside
 def _load_vae_dict(model, vae_dict_1):
@@ -211,7 +164,10 @@ def clear_loaded_vae():
     loaded_vae_file = None
 
 
-def reload_vae_weights(sd_model=None, vae_file="auto"):
+unspecified = object()
+
+
+def reload_vae_weights(sd_model=None, vae_file=unspecified):
     from modules import lowvram, devices, sd_hijack
 
     if not sd_model:
@@ -219,7 +175,11 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
 
     checkpoint_info = sd_model.sd_checkpoint_info
     checkpoint_file = checkpoint_info.filename
-    vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
+
+    if vae_file == unspecified:
+        vae_file, vae_source = resolve_vae(checkpoint_file)
+    else:
+        vae_source = "from function argument"
 
     if loaded_vae_file == vae_file:
         return
@@ -231,7 +191,7 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
 
         sd_hijack.model_hijack.undo_hijack(sd_model)
 
-    load_vae(sd_model, vae_file)
+    load_vae(sd_model, vae_file, vae_source)
 
     sd_hijack.model_hijack.hijack(sd_model)
     script_callbacks.model_loaded_callback(sd_model)
@@ -239,5 +199,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
     if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
         sd_model.to(devices.device)
 
-    print("VAE Weights loaded.")
+    print("VAE weights loaded.")
     return sd_model
diff --git a/modules/shared.py b/modules/shared.py
index e0ec3136..9756adea 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -83,7 +83,7 @@ parser.add_argument("--theme", type=str, help="launches the UI with light or dar
 parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
 parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
 parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
-parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
+parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
 parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
 parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
 parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
@@ -383,7 +383,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
     "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
     "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
-    "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
+    "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
     "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
     "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
     "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index f04d9b7e..bd3087d4 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -125,24 +125,21 @@ def apply_upscale_latent_space(p, x, xs):
 
 
 def find_vae(name: str):
-    if name.lower() in ['auto', 'none']:
-        return name
+    if name.lower() in ['auto', 'automatic']:
+        return modules.sd_vae.unspecified
+    if name.lower() == 'none':
+        return None
     else:
-        vae_path = os.path.abspath(os.path.join(paths.models_path, 'VAE'))
-        found = glob.glob(os.path.join(vae_path, f'**/{name}.*pt'), recursive=True)
-        if found:
-            return found[0]
+        choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()]
+        if len(choices) == 0:
+            print(f"No VAE found for {name}; using automatic")
+            return modules.sd_vae.unspecified
         else:
-            return 'auto'
+            return modules.sd_vae.vae_dict[choices[0]]
 
 
 def apply_vae(p, x, xs):
-    if x.lower().strip() == 'none':
-        modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file='None')
-    else:
-        found = find_vae(x)
-        if found:
-            v = modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=found)
+    modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x))
 
 
 def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
@@ -271,7 +268,9 @@ class SharedSettingsStackHelper(object):
     def __exit__(self, exc_type, exc_value, tb):
         modules.sd_models.reload_model_weights(self.model)
-        modules.sd_vae.reload_vae_weights(self.model, vae_file=find_vae(self.vae))
+
+        opts.data["sd_vae"] = self.vae
+        modules.sd_vae.reload_vae_weights(self.model)
 
         hypernetwork.load_hypernetwork(self.hypernetwork)
         hypernetwork.apply_strength()
-- cgit v1.2.3
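A note on the `unspecified = object()` sentinel introduced above: `None` is already a meaningful value for `vae_file` (it means "unload the VAE"), so a default of `None` could not distinguish "caller passed nothing" from "caller explicitly asked for no VAE". A fresh `object()` compares equal only to itself, which is exactly what the `reload_vae_weights` check relies on. A minimal standalone sketch of the idiom (illustrative only, not the webui code):

```python
unspecified = object()


def reload(vae_file=unspecified):
    if vae_file is unspecified:
        return "resolve the VAE automatically"  # argument omitted entirely
    if vae_file is None:
        return "unload the VAE"                 # caller explicitly passed None
    return f"load {vae_file}"


assert reload() == "resolve the VAE automatically"
assert reload(None) == "unload the VAE"
assert reload("some.vae.pt") == "load some.vae.pt"
```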
From f8c512478568293155539f616dce26c5e4495055 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 14 Jan 2023 20:00:12 +0300
Subject: typo?

---
 modules/sd_vae.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 6ea92711..add5cecf 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -95,7 +95,7 @@ def resolve_vae(checkpoint_file):
         return shared.cmd_opts.vae_path, 'from commandline argument'
 
     vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
-    if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "auto"):
+    if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "Automatic"):
         return vae_near_checkpoint, 'found near the checkpoint'
 
     if shared.opts.sd_vae == "None":
-- cgit v1.2.3
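With that fixed, the resolution order for a VAE is fully determined by `resolve_vae`: the `--vae-path` commandline argument beats everything, a `.vae.pt`/`.vae.ckpt`/`.vae.safetensors` file sitting next to the checkpoint wins when it is allowed to act as the default, an explicit "None" disables the VAE, and otherwise the name chosen in settings is looked up in the scanned `vae_dict`. A condensed, dependency-free restatement of that chain (a sketch for readers, not the module itself):

```python
def resolve_vae_sketch(cmd_vae_path, sd_vae, sd_vae_as_default, vae_near_checkpoint, vae_dict):
    # 1. --vae-path overrides everything
    if cmd_vae_path is not None:
        return cmd_vae_path, 'from commandline argument'
    # 2. a VAE found next to the checkpoint, when allowed to act as default
    if vae_near_checkpoint is not None and (sd_vae_as_default or sd_vae == "Automatic"):
        return vae_near_checkpoint, 'found near the checkpoint'
    # 3. the user explicitly disabled the VAE
    if sd_vae == "None":
        return None, None
    # 4. fall back to the named entry from the settings dropdown, if it exists
    return vae_dict.get(sd_vae), 'specified in settings'
```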
From 86359535d6fb0899fa9e838d27f2006b929331d5 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 14 Jan 2023 22:43:01 +0300
Subject: add buttons to copy images between img2img tabs

---
 javascript/ui.js | 21 +++++++++++++++++++--
 modules/ui.py    | 42 +++++++++++++++++++++++++++++++++++++++++-
 style.css        | 18 ++++++++++++++++++
 3 files changed, 78 insertions(+), 3 deletions(-)

diff --git a/javascript/ui.js b/javascript/ui.js
index 1e04a8f4..f8279124 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -45,10 +45,27 @@ function switch_to_txt2img(){
     return args_to_array(arguments);
 }
 
-function switch_to_img2img(){
+function switch_to_img2img_tab(no){
     gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
-    gradioApp().getElementById('mode_img2img').querySelectorAll('button')[0].click();
+    gradioApp().getElementById('mode_img2img').querySelectorAll('button')[no].click();
+}
+function switch_to_img2img(){
+    switch_to_img2img_tab(0);
+    return args_to_array(arguments);
+}
+
+function switch_to_sketch(){
+    switch_to_img2img_tab(1);
+    return args_to_array(arguments);
+}
+
+function switch_to_inpaint(){
+    switch_to_img2img_tab(2);
+    return args_to_array(arguments);
+}
 
+function switch_to_inpaint_sketch(){
+    switch_to_img2img_tab(3);
     return args_to_array(arguments);
 }
 
diff --git a/modules/ui.py b/modules/ui.py
index 2625ae32..2425c66f 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -795,19 +795,39 @@ def create_ui():
         with FormRow().style(equal_height=False):
             with gr.Column(variant='panel', elem_id="img2img_settings"):
+                copy_image_buttons = []
+                copy_image_destinations = {}
+
+                def add_copy_image_controls(tab_name, elem):
+                    with gr.Row(variant="compact", elem_id=f"img2img_copy_to_{tab_name}"):
+                        gr.HTML("Copy image to: ", elem_id=f"img2img_label_copy_to_{tab_name}")
+
+                        for title, name in zip(['img2img', 'sketch', 'inpaint', 'inpaint sketch'], ['img2img', 'sketch', 'inpaint', 'inpaint_sketch']):
+                            if name == tab_name:
+                                gr.Button(title, interactive=False)
+                                copy_image_destinations[name] = elem
+                                continue
+
+                            button = gr.Button(title)
+                            copy_image_buttons.append((button, name, elem))
+
                 with gr.Tabs(elem_id="mode_img2img"):
                     with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
                         init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480)
+                        add_copy_image_controls('img2img', init_img)
 
                     with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
                         sketch = gr.Image(label="Image for img2img", elem_id="img2img_sketch", show_label=False, source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
+                        add_copy_image_controls('sketch', sketch)
 
                     with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
                         init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=480)
+                        add_copy_image_controls('inpaint', init_img_with_mask)
 
                     with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
                         inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=480)
                         inpaint_color_sketch_orig = gr.State(None)
+                        add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
 
                         def update_orig(image, state):
                             if image is not None:
@@ -824,10 +844,29 @@ def create_ui():
                     with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch:
                         hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else ''
-                        gr.HTML(f"<p>Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
+                        gr.HTML(f"<p>Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>")
                         img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir")
                         img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir")
 
+                def copy_image(img):
+                    if isinstance(img, dict) and 'image' in img:
+                        return img['image']
+
+                    return img
+
+                for button, name, elem in copy_image_buttons:
+                    button.click(
+                        fn=copy_image,
+                        inputs=[elem],
+                        outputs=[copy_image_destinations[name]],
+                    )
+                    button.click(
+                        fn=lambda: None,
+                        _js="switch_to_"+name.replace(" ", "_"),
+                        inputs=[],
+                        outputs=[],
+                    )
+
                 with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
                     with FormRow():
                         mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur")
@@ -856,6 +895,7 @@ def create_ui():
                         outputs=[inpaint_controls, mask_alpha],
                     )
 
+
                 with FormRow():
                     resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")
 
diff --git a/style.css b/style.css
index ffd6307f..2d484e06 100644
--- a/style.css
+++ b/style.css
@@ -676,6 +676,24 @@ footer {
     opacity: 0.5;
 }
 
+#mode_img2img > div > div{
+    gap: 0 !important;
+}
+
+[id*='img2img_copy_to_'] {
+    border: none;
+}
+
+[id*='img2img_copy_to_'] > button {
+}
+
+[id*='img2img_label_copy_to_'] {
+    font-size: 1.0em;
+    font-weight: bold;
+    text-align: center;
+    line-height: 2.4em;
+}
+
 /* The following handles localization for right-to-left (RTL) languages like Arabic.
 The rtl media type will only be activated by the logic in javascript/localization.js.
 If you change anything above, you need to make sure it is RTL compliant by just running
-- cgit v1.2.3


From 2e172cf831a928223e93803b94896325bd4c22a7 Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Sat, 14 Jan 2023 22:25:32 +0100
Subject: Only set TORCH_COMMAND if wasn't set webui-user

---
 webui.sh | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/webui.sh b/webui.sh
index fcba6b7d..35542ed6 100755
--- a/webui.sh
+++ b/webui.sh
@@ -168,7 +168,10 @@ else
     gpu_info=$(lspci | grep VGA)
     if echo "$gpu_info" | grep -q "AMD"
     then
-        export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
+        if [ -z ${TORCH_COMMAND+x} ]
+        then
+            export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
+        fi
         HSA_OVERRIDE_GFX_VERSION=10.3.0 exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
     else
         exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@"
-- cgit v1.2.3


From ba077e2110cab891a46d14665fb161ce0669f31e Mon Sep 17 00:00:00 2001
From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com>
Date: Sat, 14 Jan 2023 23:19:52 +0100
Subject: Fix TORCH_COMMAND check

---
 webui.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/webui.sh b/webui.sh
index 35542ed6..6e07778f 100755
--- a/webui.sh
+++ b/webui.sh
@@ -168,7 +168,7 @@ else
     gpu_info=$(lspci | grep VGA)
     if echo "$gpu_info" | grep -q "AMD"
     then
-        if [ -z ${TORCH_COMMAND+x} ]
+        if [[ -z "${TORCH_COMMAND}" ]]
        then
             export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2"
         fi
-- cgit v1.2.3
Inpaint & Sketch --- javascript/aspectRatioOverlay.js | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/javascript/aspectRatioOverlay.js b/javascript/aspectRatioOverlay.js index 66f26a22..0f164b82 100644 --- a/javascript/aspectRatioOverlay.js +++ b/javascript/aspectRatioOverlay.js @@ -21,11 +21,16 @@ function dimensionChange(e, is_width, is_height){ var targetElement = null; var tabIndex = get_tab_index('mode_img2img') - if(tabIndex == 0){ + if(tabIndex == 0){ // img2img targetElement = gradioApp().querySelector('div[data-testid=image] img'); - } else if(tabIndex == 1){ + } else if(tabIndex == 1){ //Sketch + targetElement = gradioApp().querySelector('#img2img_sketch div[data-testid=image] img'); + } else if(tabIndex == 2){ // Inpaint targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img'); + } else if(tabIndex == 3){ // Inpaint sketch + targetElement = gradioApp().querySelector('#inpaint_sketch div[data-testid=image] img'); } + if(targetElement){ -- cgit v1.2.3 From 9ef41df6f9043d58fbbeea1f06be8e5c8622248b Mon Sep 17 00:00:00 2001 From: Josh R Date: Sat, 14 Jan 2023 15:26:45 -0800 Subject: add inpaint masking controls to orderable section that the settings can order --- modules/shared.py | 1 + modules/ui.py | 58 +++++++++++++++++++++++++++---------------------------- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index 51df056c..7ce8003f 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -116,6 +116,7 @@ restricted_opts = { } ui_reorder_categories = [ + "masking", "sampler", "dimensions", "cfg", diff --git a/modules/ui.py b/modules/ui.py index 2425c66f..174930ab 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -867,35 +867,6 @@ def create_ui(): outputs=[], ) - with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: - with FormRow(): - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") - mask_alpha = gr.Slider(label="Mask transparency", visible=False, elem_id="img2img_mask_alpha") - - with FormRow(): - inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode") - - with FormRow(): - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index", elem_id="img2img_inpainting_fill") - - with FormRow(): - with gr.Column(): - inpaint_full_res = gr.Radio(label="Inpaint area", choices=["Whole picture", "Only masked"], type="index", value="Whole picture", elem_id="img2img_inpaint_full_res") - - with gr.Column(scale=4): - inpaint_full_res_padding = gr.Slider(label='Only masked padding, pixels', minimum=0, maximum=256, step=4, value=32, elem_id="img2img_inpaint_full_res_padding") - - def select_img2img_tab(tab): - return gr.update(visible=tab in [2, 3, 4]), gr.update(visible=tab == 3), - - for i, elem in enumerate([tab_img2img, tab_sketch, tab_inpaint, tab_inpaint_color, tab_inpaint_upload, tab_batch]): - elem.select( - fn=lambda tab=i: select_img2img_tab(tab), - inputs=[], - outputs=[inpaint_controls, mask_alpha], - ) - - with FormRow(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") @@ -937,6 +908,35 @@ def create_ui(): with 
From d97f467c0d27695d23edad5e4f8898a57e0ccb00 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 15 Jan 2023 09:24:21 +0300
Subject: add license file

---
 LICENSE.txt | 663 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 663 insertions(+)
 create mode 100644 LICENSE.txt

diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100644
index 00000000..14577543
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,663 @@
+                    GNU AFFERO GENERAL PUBLIC LICENSE
+                       Version 3, 19 November 2007
+
+ Copyright (c) 2023 AUTOMATIC1111
+
+ Copyright (C) 2007 Free Software Foundation, Inc.
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+                            Preamble
+
+  The GNU Affero General Public License is a free, copyleft license for
+software and other kinds of works, specifically designed to ensure
+cooperation with the community in the case of network server software.
+
+  The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works.  By contrast,
+our General Public Licenses are intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users.
+
+  When we speak of free software, we are referring to freedom, not
+price.  Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+ + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. 
If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. 
+ + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. 
+ + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. 
+ + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. 
If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). 
To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. 
+ + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. 
+ + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. 
-- cgit v1.2.3


From eef1990a5e6c41ecb6943ff5529316ad5ededb2a Mon Sep 17 00:00:00 2001
From: brkirch
Date: Sun, 15 Jan 2023 08:13:33 -0500
Subject: Fix Approx NN on devices other than CUDA

---
 modules/sd_vae_approx.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index 0a58542d..0027343a 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -36,7 +36,7 @@ def model():
 
     if sd_vae_approx_model is None:
         sd_vae_approx_model = VAEApprox()
-        sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt")))
+        sd_vae_approx_model.load_state_dict(torch.load(os.path.join(paths.models_path, "VAE-approx", "model.pt"), map_location='cpu' if devices.device.type != 'cuda' else None))
         sd_vae_approx_model.eval()
         sd_vae_approx_model.to(devices.device, devices.dtype)
-- cgit v1.2.3


From f0312565e5b4d56a421af889a9a8eaea0ba92959 Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Sun, 15 Jan 2023 09:42:34 -0500
Subject: increase block size

---
 modules/hashes.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/modules/hashes.py b/modules/hashes.py
index 14231771..b85a7580 100644
--- a/modules/hashes.py
+++ b/modules/hashes.py
@@ -34,9 +34,10 @@ def cache(subsection):
 
 def calculate_sha256(filename):
     hash_sha256 = hashlib.sha256()
+    blksize = 1024 * 1024
 
     with open(filename, "rb") as f:
-        for chunk in iter(lambda: f.read(4096), b""):
+        for chunk in iter(lambda: f.read(blksize), b""):
             hash_sha256.update(chunk)
 
     return hash_sha256.hexdigest()
-- cgit v1.2.3
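Two quick notes on the pair of commits above. The `map_location` fix works because `torch.load` by default restores tensors onto the device they were saved from, which fails when the approx-VAE weights were serialized on CUDA but the host only has CPU or MPS; forcing a CPU load and then moving the module with `.to(device)` is device-independent. The block-size bump is a pure I/O optimization: hashing a multi-gigabyte checkpoint is dominated by read calls, and 1 MiB chunks mean roughly 256 times fewer Python-level iterations than 4 KiB chunks while memory use stays flat. The chunked-hash pattern, for reference (equivalent to the function in the diff; the file path is whatever checkpoint you point it at):

```python
import hashlib

def calculate_sha256(filename, blksize=1024 * 1024):
    # stream the file in fixed-size chunks so memory stays constant
    # no matter how large the checkpoint is
    hash_sha256 = hashlib.sha256()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(blksize), b""):
            hash_sha256.update(chunk)
    return hash_sha256.hexdigest()
```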
gradioApp().querySelector("#"+id_progressbar) - } else{ - progressbar = gradioApp().getElementById(id_progressbar) - progressbarParent = null - } - var skip = id_skip ? gradioApp().getElementById(id_skip) : null - var interrupt = gradioApp().getElementById(id_interrupt) - - if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){ - if(progressbar.innerText){ - let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion'; - if(document.title != newtitle){ - document.title = newtitle; - } - }else{ - let newtitle = 'Stable Diffusion' - if(document.title != newtitle){ - document.title = newtitle; - } - } - } - - if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){ - global_progressbars[id_progressbar] = progressbar - - var mutationObserver = new MutationObserver(function(m){ - if(timeoutIds[id_part]) return; - - preview = gradioApp().getElementById(id_preview) - gallery = gradioApp().getElementById(id_gallery) +galleries = {} +storedGallerySelections = {} +galleryObservers = {} - if(preview != null && gallery != null){ - preview.style.width = gallery.clientWidth + "px" - preview.style.height = gallery.clientHeight + "px" - if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px" +function rememberGallerySelection(id_gallery){ + storedGallerySelections[id_gallery] = getGallerySelectedIndex(id_gallery) +} - //only watch gallery if there is a generation process going on - check_gallery(id_gallery); +function getGallerySelectedIndex(id_gallery){ + let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item') + let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2') - var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; - if(progressDiv){ - timeoutIds[id_part] = window.setTimeout(function() { - timeoutIds[id_part] = null - requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) - }, 500) - } else{ - if (skip) { - skip.style.display = "none" - } - interrupt.style.display = "none" + let currentlySelectedIndex = -1 + galleryButtons.forEach(function(v, i){ if(v==galleryBtnSelected) { currentlySelectedIndex = i } }) - //disconnect observer once generation finished, so user can close selected image if they want - if (galleryObservers[id_gallery]) { - galleryObservers[id_gallery].disconnect(); - galleries[id_gallery] = null; - } - } - } - - }); - mutationObserver.observe( progressbar, { childList:true, subtree:true }) - } + return currentlySelectedIndex } +// this is a workaround for https://github.com/gradio-app/gradio/issues/2984 function check_gallery(id_gallery){ let gallery = gradioApp().getElementById(id_gallery) // if gallery has no change, no need to setting up observer again. 
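
The hunks below replace the old MutationObserver-driven progressbar with
explicit polling of a new backend endpoint. As a minimal sketch of the same
protocol in Python (the host/port and the third-party requests library are
assumptions for illustration; the endpoint and field names come from
modules/progress.py added later in this commit):

    import time
    import requests  # assumed third-party HTTP client, not part of the patch

    def poll_progress(id_task, url="http://127.0.0.1:7860/internal/progress"):
        id_live_preview = -1
        while True:
            res = requests.post(url, json={"id_task": id_task, "id_live_preview": id_live_preview}).json()
            if res.get("completed"):
                break
            id_live_preview = res.get("id_live_preview", id_live_preview)
            print(res.get("progress"), res.get("eta"), res.get("textinfo"))
            time.sleep(0.5)  # the JS below re-polls on a 500 ms timer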
@@ -85,10 +28,16 @@ function check_gallery(id_gallery){ if(galleryObservers[id_gallery]){ galleryObservers[id_gallery].disconnect(); } - let prevSelectedIndex = selected_gallery_index(); + + storedGallerySelections[id_gallery] = -1 + galleryObservers[id_gallery] = new MutationObserver(function (){ let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item') let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2') + let currentlySelectedIndex = getGallerySelectedIndex(id_gallery) + prevSelectedIndex = storedGallerySelections[id_gallery] + storedGallerySelections[id_gallery] = -1 + if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) { // automatically re-open previously selected index (if exists) activeElement = gradioApp().activeElement; @@ -120,30 +69,150 @@ function check_gallery(id_gallery){ } onUiUpdate(function(){ - check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') - check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') - check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery') + check_gallery('txt2img_gallery') + check_gallery('img2img_gallery') }) -function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){ - btn = gradioApp().getElementById(id_part+"_check_progress"); - if(btn==null) return; - - btn.click(); - var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; - var skip = id_skip ? gradioApp().getElementById(id_skip) : null - var interrupt = gradioApp().getElementById(id_interrupt) - if(progressDiv && interrupt){ - if (skip) { - skip.style.display = "block" +function request(url, data, handler, errorHandler){ + var xhr = new XMLHttpRequest(); + var url = url; + xhr.open("POST", url, true); + xhr.setRequestHeader("Content-Type", "application/json"); + xhr.onreadystatechange = function () { + if (xhr.readyState === 4) { + if (xhr.status === 200) { + var js = JSON.parse(xhr.responseText); + handler(js) + } else{ + errorHandler() + } } - interrupt.style.display = "block" + }; + var js = JSON.stringify(data); + xhr.send(js); +} + +function pad2(x){ + return x<10 ? '0'+x : x +} + +function formatTime(secs){ + if(secs > 3600){ + return pad2(Math.floor(secs/60/60)) + ":" + pad2(Math.floor(secs/60)%60) + ":" + pad2(Math.floor(secs)%60) + } else if(secs > 60){ + return pad2(Math.floor(secs/60)) + ":" + pad2(Math.floor(secs)%60) + } else{ + return Math.floor(secs) + "s" } } -function requestProgress(id_part){ - btn = gradioApp().getElementById(id_part+"_check_progress_initial"); - if(btn==null) return; +function randomId(){ + return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")" +} + +// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and +// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd. 
+// calls onProgress every time there is a progress update +function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress){ + var dateStart = new Date() + var wasEverActive = false + var parentProgressbar = progressbarContainer.parentNode + var parentGallery = gallery.parentNode + + var divProgress = document.createElement('div') + divProgress.className='progressDiv' + var divInner = document.createElement('div') + divInner.className='progress' + + divProgress.appendChild(divInner) + parentProgressbar.insertBefore(divProgress, progressbarContainer) + + var livePreview = document.createElement('div') + livePreview.className='livePreview' + parentGallery.insertBefore(livePreview, gallery) + + var removeProgressBar = function(){ + parentProgressbar.removeChild(divProgress) + parentGallery.removeChild(livePreview) + atEnd() + } + + var fun = function(id_task, id_live_preview){ + request("/internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){ + console.log(res) + + if(res.completed){ + removeProgressBar() + return + } + + var rect = progressbarContainer.getBoundingClientRect() + + if(rect.width){ + divProgress.style.width = rect.width + "px"; + } + + progressText = "" + + divInner.style.width = ((res.progress || 0) * 100.0) + '%' + + if(res.progress > 0){ + progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%' + } + + if(res.eta){ + progressText += " ETA: " + formatTime(res.eta) + } else if(res.textinfo){ + progressText += " " + res.textinfo + } + + divInner.textContent = progressText + + var elapsedFromStart = (new Date() - dateStart) / 1000 + + if(res.active) wasEverActive = true; + + if(! res.active && wasEverActive){ + removeProgressBar() + return + } + + if(elapsedFromStart > 5 && !res.queued && !res.active){ + removeProgressBar() + return + } + + + if(res.live_preview){ + var img = new Image(); + img.onload = function() { + var rect = gallery.getBoundingClientRect() + if(rect.width){ + livePreview.style.width = rect.width + "px" + livePreview.style.height = rect.height + "px" + } + + livePreview.innerHTML = '' + livePreview.appendChild(img) + if(livePreview.childElementCount > 2){ + livePreview.removeChild(livePreview.firstElementChild) + } + } + img.src = res.live_preview; + } + + + if(onProgress){ + onProgress(res) + } + + setTimeout(() => { + fun(id_task, res.id_live_preview); + }, 500) + }, function(){ + removeProgressBar() + }) + } - btn.click(); + fun(id_task, 0) } diff --git a/javascript/textualInversion.js b/javascript/textualInversion.js index 8061be08..0354b860 100644 --- a/javascript/textualInversion.js +++ b/javascript/textualInversion.js @@ -1,8 +1,17 @@ + function start_training_textual_inversion(){ - requestProgress('ti') gradioApp().querySelector('#ti_error').innerHTML='' - return args_to_array(arguments) + var id = randomId() + requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){ + gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo + }) + + var res = args_to_array(arguments) + + res[0] = id + + return res } diff --git a/javascript/ui.js b/javascript/ui.js index f8279124..ecf97cb3 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -126,18 +126,41 @@ function create_submit_args(args){ return res } +function showSubmitButtons(tabname, show){ + gradioApp().getElementById(tabname+'_interrupt').style.display = show ? "none" : "block" + gradioApp().getElementById(tabname+'_skip').style.display = show ? 
"none" : "block" +} + function submit(){ - requestProgress('txt2img') + rememberGallerySelection('txt2img_gallery') + showSubmitButtons('txt2img', false) + + var id = randomId() + requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){ + showSubmitButtons('txt2img', true) + + }) - return create_submit_args(arguments) + var res = create_submit_args(arguments) + + res[0] = id + + return res } function submit_img2img(){ - requestProgress('img2img') + rememberGallerySelection('img2img_gallery') + showSubmitButtons('img2img', false) + + var id = randomId() + requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){ + showSubmitButtons('img2img', true) + }) - res = create_submit_args(arguments) + var res = create_submit_args(arguments) - res[0] = get_tab_index('mode_img2img') + res[0] = id + res[1] = get_tab_index('mode_img2img') return res } diff --git a/modules/call_queue.py b/modules/call_queue.py index 4cd49533..92097c15 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -4,7 +4,7 @@ import threading import traceback import time -from modules import shared +from modules import shared, progress queue_lock = threading.Lock() @@ -22,12 +22,23 @@ def wrap_queued_call(func): def wrap_gradio_gpu_call(func, extra_outputs=None): def f(*args, **kwargs): - shared.state.begin() + # if the first argument is a string that says "task(...)", it is treated as a job id + if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")": + id_task = args[0] + progress.add_task_to_queue(id_task) + else: + id_task = None with queue_lock: - res = func(*args, **kwargs) + shared.state.begin() + progress.start_task(id_task) + + try: + res = func(*args, **kwargs) + finally: + progress.finish_task(id_task) - shared.state.end() + shared.state.end() return res diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3aebefa8..ae6af516 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -453,7 +453,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. 
from modules import images @@ -629,7 +629,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}" pbar.set_description(description) - shared.state.textinfo = description if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: # Before saving, change name to match current checkpoint. hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' @@ -701,7 +700,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, torch.cuda.set_rng_state_all(cuda_rng_state) hypernetwork.train() if image is not None: - shared.state.current_image = image + shared.state.assign_current_image(image) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" diff --git a/modules/img2img.py b/modules/img2img.py index f62783c6..f4a03c57 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): +def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): is_batch = mode == 5 if mode == 0: # img2img diff --git a/modules/progress.py b/modules/progress.py new file mode 100644 index 00000000..3327b883 --- /dev/null +++ b/modules/progress.py @@ -0,0 +1,96 @@ +import base64 +import io +import time + +import gradio as gr +from pydantic import BaseModel, Field + +from modules.shared import opts + +import modules.shared as shared + + +current_task = None +pending_tasks = {} +finished_tasks = [] + + +def start_task(id_task): + global current_task + + current_task = id_task + pending_tasks.pop(id_task, None) + + +def finish_task(id_task): + global current_task + + if current_task == id_task: + current_task = None + + finished_tasks.append(id_task) + if len(finished_tasks) > 16: + finished_tasks.pop(0) + + 
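
# A sketch of the intended lifecycle of these registry helpers, as driven by
# wrap_gradio_gpu_call() in the call_queue.py hunk above (the task id is a
# made-up example; real ids look like "task(xxxxx)" from randomId() in the JS):
#
#     add_task_to_queue("task(abc12fgh34)")   # request arrives, id is pending
#     start_task("task(abc12fgh34)")          # queue lock held, id is current
#     try:
#         ...                                 # run txt2img/img2img/training
#     finally:
#         finish_task("task(abc12fgh34)")     # id kept in the last-16 history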
+def add_task_to_queue(id_job): + pending_tasks[id_job] = time.time() + + +class ProgressRequest(BaseModel): + id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for") + id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image") + + +class ProgressResponse(BaseModel): + active: bool = Field(title="Whether the task is being worked on right now") + queued: bool = Field(title="Whether the task is in queue") + completed: bool = Field(title="Whether the task has already finished") + progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1") + eta: float = Field(default=None, title="ETA in secs") + live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri") + id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image") + textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.") + + +def setup_progress_api(app): + return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) + + +def progressapi(req: ProgressRequest): + active = req.id_task == current_task + queued = req.id_task in pending_tasks + completed = req.id_task in finished_tasks + + if not active: + return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...") + + progress = 0 + + if shared.state.job_count > 0: + progress += shared.state.job_no / shared.state.job_count + if shared.state.sampling_steps > 0: + progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps + + progress = min(progress, 1) + + elapsed_since_start = time.time() - shared.state.time_start + predicted_duration = elapsed_since_start / progress if progress > 0 else None + eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None + + id_live_preview = req.id_live_preview + shared.state.set_current_image() + if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview: + image = shared.state.current_image + if image is not None: + buffered = io.BytesIO() + image.save(buffered, format="png") + live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii") + id_live_preview = shared.state.id_live_preview + else: + live_preview = None + else: + live_preview = None + + return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo) + diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 7616fded..76e0e0d5 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -140,7 +140,7 @@ def store_latent(decoded): if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: if not shared.parallel_processing_allowed: - shared.state.current_image = sample_to_image(decoded) + shared.state.assign_current_image(sample_to_image(decoded)) class InterruptedException(BaseException): diff --git a/modules/shared.py b/modules/shared.py index 51df056c..de99aca9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -152,6 +152,7 @@ def reload_hypernetworks(): 
hypernetwork.load_hypernetwork(opts.sd_hypernetwork) + class State: skipped = False interrupted = False @@ -165,6 +166,7 @@ class State: current_latent = None current_image = None current_image_sampling_step = 0 + id_live_preview = 0 textinfo = None time_start = None need_restart = False @@ -207,6 +209,7 @@ class State: self.current_latent = None self.current_image = None self.current_image_sampling_step = 0 + self.id_live_preview = 0 self.skipped = False self.interrupted = False self.textinfo = None @@ -220,8 +223,8 @@ class State: devices.torch_gc() - """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" def set_current_image(self): + """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" if not parallel_processing_allowed: return @@ -234,12 +237,16 @@ class State: import modules.sd_samplers if opts.show_progress_grid: - self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent) + self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent)) else: - self.current_image = modules.sd_samplers.sample_to_image(self.current_latent) + self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent)) self.current_image_sampling_step = self.sampling_step + def assign_current_image(self, image): + self.current_image = image + self.id_live_preview += 1 + state = State() state.server_start = time.time() @@ -424,8 +431,6 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), })) options_templates.update(options_section(('ui', "User interface"), { - "show_progressbar": OptionInfo(True, "Show progressbar"), - "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), @@ -446,6 +451,7 @@ options_templates.update(options_section(('ui', "User interface"), { options_templates.update(options_section(('ui', "Live previews"), { "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), + "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. 
Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}), "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 3c1042ad..64abff4d 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -12,7 +12,7 @@ from modules.shared import opts, cmd_opts from modules.textual_inversion import autocrop -def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False): +def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False): try: if process_caption: shared.interrogator.load() diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 63935878..7e4a6d24 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -345,7 +345,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 template_file = textual_inversion_templates.get(template_filename, None) @@ -510,7 +510,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}" pbar.set_description(description) - shared.state.textinfo = description if embedding_dir is not None and steps_done % save_embedding_every == 0: # Before saving, change name to match current checkpoint. 
embedding_name_every = f'{embedding_name}-{steps_done}' @@ -560,7 +559,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ shared.sd_model.first_stage_model.to(devices.cpu) if image is not None: - shared.state.current_image = image + shared.state.assign_current_image(image) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" diff --git a/modules/txt2img.py b/modules/txt2img.py index 38b5f591..ca5d4550 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -8,7 +8,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args): +def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, diff --git a/modules/ui.py b/modules/ui.py index 2425c66f..ff33236b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -356,7 +356,7 @@ def create_toprow(is_img2img): button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") with gr.Column(scale=1): - with gr.Row(): + with gr.Row(elem_id=f"{id_part}_generate_box"): skip = gr.Button('Skip', elem_id=f"{id_part}_skip") interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') @@ -384,9 +384,7 @@ def create_toprow(is_img2img): def setup_progressbar(*args, **kwargs): - import modules.ui_progress - - modules.ui_progress.setup_progressbar(*args, **kwargs) + pass def apply_setting(key, value): @@ -479,8 +477,8 @@ Requested path was: {f} else: sp.Popen(["xdg-open", path]) - with gr.Column(variant='panel'): - with gr.Group(): + with gr.Column(variant='panel', elem_id=f"{tabname}_results"): + with gr.Group(elem_id=f"{tabname}_gallery_container"): result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) generation_info = None @@ -595,15 +593,6 @@ def create_ui(): dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - with gr.Row(elem_id='txt2img_progress_row'): - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="txt2img_progressbar") - txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) - setup_progressbar(progressbar, 
txt2img_preview, 'txt2img') - with gr.Row().style(equal_height=False): with gr.Column(variant='panel', elem_id="txt2img_settings"): for category in ordered_ui_categories(): @@ -682,6 +671,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), _js="submit", inputs=[ + dummy_component, txt2img_prompt, txt2img_negative_prompt, txt2img_prompt_style, @@ -782,16 +772,7 @@ def create_ui(): with gr.Blocks(analytics_enabled=False) as img2img_interface: img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) - with gr.Row(elem_id='img2img_progress_row'): - img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="img2img_progressbar") - img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) - setup_progressbar(progressbar, img2img_preview, 'img2img') + img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) with FormRow().style(equal_height=False): with gr.Column(variant='panel', elem_id="img2img_settings"): @@ -958,6 +939,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), _js="submit_img2img", inputs=[ + dummy_component, dummy_component, img2img_prompt, img2img_negative_prompt, @@ -1335,15 +1317,11 @@ def create_ui(): script_callbacks.ui_train_tabs_callback(params) - with gr.Column(): - progressbar = gr.HTML(elem_id="ti_progressbar") + with gr.Column(elem_id='ti_gallery_container'): ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) - ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) - ti_preview = gr.Image(elem_id='ti_preview', visible=False) ti_progress = gr.HTML(elem_id="ti_progress", value="") ti_outcome = gr.HTML(elem_id="ti_error", value="") - setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) create_embedding.click( fn=modules.textual_inversion.ui.create_embedding, @@ -1384,6 +1362,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ + dummy_component, process_src, process_dst, process_width, @@ -1411,6 +1390,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ + dummy_component, train_embedding_name, embedding_learn_rate, batch_size, @@ -1443,6 +1423,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ + dummy_component, train_hypernetwork_name, hypernetwork_learn_rate, batch_size, diff --git a/modules/ui_progress.py b/modules/ui_progress.py deleted file mode 100644 index 7cd312e4..00000000 --- a/modules/ui_progress.py +++ /dev/null @@ -1,101 +0,0 @@ -import time - -import gradio as gr - -from modules.shared import opts - -import modules.shared as shared - - -def calc_time_left(progress, threshold, label, force_display, show_eta): - if progress == 0: - return "" - else: - time_since_start = time.time() - shared.state.time_start - eta = 
(time_since_start/progress)
-        eta_relative = eta-time_since_start
-        if (eta_relative > threshold and show_eta) or force_display:
-            if eta_relative > 3600:
-                return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
-            elif eta_relative > 60:
-                return label + time.strftime('%M:%S', time.gmtime(eta_relative))
-            else:
-                return label + time.strftime('%Ss', time.gmtime(eta_relative))
-        else:
-            return ""
-
-
-def check_progress_call(id_part):
-    if shared.state.job_count == 0:
-        return "", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
-
-    progress = 0
-
-    if shared.state.job_count > 0:
-        progress += shared.state.job_no / shared.state.job_count
-    if shared.state.sampling_steps > 0:
-        progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
-
-    # Show progress percentage and time left at the same moment, and base it also on steps done
-    show_eta = progress >= 0.01 or shared.state.sampling_step >= 10
-
-    time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta)
-    if time_left != "":
-        shared.state.time_left_force_display = True
-
-    progress = min(progress, 1)
-
-    progressbar = ""
-    if opts.show_progressbar:
-        progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:visible;width:{progress * 100}%;white-space:nowrap;">{"&nbsp;" * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}</div></div>"""
-
-    image = gr.update(visible=False)
-    preview_visibility = gr.update(visible=False)
-
-    if opts.live_previews_enable:
-        shared.state.set_current_image()
-        image = shared.state.current_image
-
-        if image is None:
-            image = gr.update(value=None)
-        else:
-            preview_visibility = gr.update(visible=True)
-
-    if shared.state.textinfo is not None:
-        textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
-    else:
-        textinfo_result = gr.update(visible=False)
-
-    return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>
", preview_visibility, image, textinfo_result - - -def check_progress_call_initial(id_part): - shared.state.job_count = -1 - shared.state.current_latent = None - shared.state.current_image = None - shared.state.textinfo = None - shared.state.time_start = time.time() - shared.state.time_left_force_display = False - - return check_progress_call(id_part) - - -def setup_progressbar(progressbar, preview, id_part, textinfo=None): - if textinfo is None: - textinfo = gr.HTML(visible=False) - - check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) - check_progress.click( - fn=lambda: check_progress_call(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) - check_progress_initial.click( - fn=lambda: check_progress_call_initial(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) diff --git a/style.css b/style.css index 2d484e06..786b71d1 100644 --- a/style.css +++ b/style.css @@ -305,26 +305,42 @@ input[type="range"]{ } .progressDiv{ - width: 100%; - height: 20px; - background: #b4c0cc; - border-radius: 8px; + position: absolute; + height: 20px; + top: -20px; + background: #b4c0cc; + border-radius: 8px !important; } .dark .progressDiv{ - background: #424c5b; + background: #424c5b; } .progressDiv .progress{ - width: 0%; - height: 20px; - background: #0060df; - color: white; - font-weight: bold; - line-height: 20px; - padding: 0 8px 0 0; - text-align: right; - border-radius: 8px; + width: 0%; + height: 20px; + background: #0060df; + color: white; + font-weight: bold; + line-height: 20px; + padding: 0 8px 0 0; + text-align: right; + border-radius: 8px; + overflow: visible; + white-space: nowrap; +} + +.livePreview{ + position: absolute; + z-index: 300; + background-color: white; + margin: -4px; +} + +.livePreview img{ + object-fit: contain; + width: 100%; + height: 100%; } #lightboxModal{ @@ -450,23 +466,25 @@ input[type="range"]{ display:none } -#txt2img_interrupt, #img2img_interrupt{ - position: absolute; - width: 50%; - height: 72px; - background: #b4c0cc; - border-radius: 0px; - display: none; +#txt2img_generate_box, #img2img_generate_box{ + position: relative; +} + +#txt2img_interrupt, #img2img_interrupt, #txt2img_skip, #img2img_skip{ + position: absolute; + width: 50%; + height: 100%; + background: #b4c0cc; + display: none; } +#txt2img_interrupt, #img2img_interrupt{ + right: 0; + border-radius: 0 0.5rem 0.5rem 0; +} #txt2img_skip, #img2img_skip{ - position: absolute; - width: 50%; - right: 0px; - height: 72px; - background: #b4c0cc; - border-radius: 0px; - display: none; + left: 0; + border-radius: 0.5rem 0 0 0.5rem; } .red { diff --git a/webui.py b/webui.py index 1fff80da..4624fe18 100644 --- a/webui.py +++ b/webui.py @@ -34,6 +34,7 @@ import modules.sd_vae import modules.txt2img import modules.script_callbacks import modules.textual_inversion.textual_inversion +import modules.progress import modules.ui from modules import modelloader @@ -181,6 +182,8 @@ def webui(): app.add_middleware(GZipMiddleware, minimum_size=1000) + modules.progress.setup_progress_api(app) + if launch_api: create_api(app) -- cgit v1.2.3 From 388708f7b13dfbc890135cad678bfbcebd7baf37 Mon Sep 17 00:00:00 2001 From: pangbo13 <373108669@qq.com> Date: Mon, 16 Jan 2023 00:56:24 +0800 Subject: fix when show_progress_every_n_steps == -1 --- modules/shared.py | 2 +- 1 
file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index de99aca9..f857ccde 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -228,7 +228,7 @@ class State: if not parallel_processing_allowed: return - if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable: + if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps != -1: self.do_set_current_image() def do_set_current_image(self): -- cgit v1.2.3 From 16f410893eb96c7810cbbd812541ba35e0e92524 Mon Sep 17 00:00:00 2001 From: AngelBottomless <35677394+aria1th@users.noreply.github.com> Date: Mon, 16 Jan 2023 02:08:47 +0900 Subject: fix missing 'mean loss' for tensorboard integration --- modules/hypernetworks/hypernetwork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index ae6af516..bbd1f673 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -644,7 +644,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi if shared.opts.training_enable_tensorboard: epoch_num = hypernetwork.step // len(ds) epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1 - + mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values()) textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num) textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, { -- cgit v1.2.3 From f6aac4c65a681383616f6e72e3865002600f476f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 15 Jan 2023 20:20:29 +0300 Subject: eliminate flicker for live previews --- javascript/progressbar.js | 14 +++++++------- style.css | 1 + 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/javascript/progressbar.js b/javascript/progressbar.js index b7524ef7..8f22c018 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -184,15 +184,15 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre if(res.live_preview){ + + var rect = gallery.getBoundingClientRect() + if(rect.width){ + livePreview.style.width = rect.width + "px" + livePreview.style.height = rect.height + "px" + } + var img = new Image(); img.onload = function() { - var rect = gallery.getBoundingClientRect() - if(rect.width){ - livePreview.style.width = rect.width + "px" - livePreview.style.height = rect.height + "px" - } - - livePreview.innerHTML = '' livePreview.appendChild(img) if(livePreview.childElementCount > 2){ livePreview.removeChild(livePreview.firstElementChild) diff --git a/style.css b/style.css index 786b71d1..5bf1c6f9 100644 --- a/style.css +++ b/style.css @@ -338,6 +338,7 @@ input[type="range"]{ } .livePreview img{ + position: absolute; object-fit: contain; width: 100%; height: 100%; -- cgit v1.2.3 From a534bdfc801e0c83e378dfaa2d04cf865d7109f9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 15 Jan 2023 20:27:39 +0300 Subject: add setting for progressbar update period --- javascript/progressbar.js | 2 +- modules/shared.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 
8f22c018..59173c83 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -208,7 +208,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre setTimeout(() => { fun(id_task, res.id_live_preview); - }, 500) + }, opts.live_preview_refresh_period || 500) }, function(){ removeProgressBar() }) diff --git a/modules/shared.py b/modules/shared.py index de99aca9..3483db1c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -455,6 +455,7 @@ options_templates.update(options_section(('ui', "Live previews"), { "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}), "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), + "live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds") })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { -- cgit v1.2.3 From b6ce041cdf722b400df9b5eac306d0cb049923d7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 15 Jan 2023 20:29:48 +0300 Subject: put interrupt and skip buttons back where they were --- modules/ui.py | 2 +- style.css | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index ff33236b..7a357f9a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -357,8 +357,8 @@ def create_toprow(is_img2img): with gr.Column(scale=1): with gr.Row(elem_id=f"{id_part}_generate_box"): - skip = gr.Button('Skip', elem_id=f"{id_part}_skip") interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") + skip = gr.Button('Skip', elem_id=f"{id_part}_skip") submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') skip.click( diff --git a/style.css b/style.css index 5bf1c6f9..750fe315 100644 --- a/style.css +++ b/style.css @@ -480,13 +480,13 @@ input[type="range"]{ } #txt2img_interrupt, #img2img_interrupt{ - right: 0; - border-radius: 0 0.5rem 0.5rem 0; -} -#txt2img_skip, #img2img_skip{ left: 0; border-radius: 0.5rem 0 0 0.5rem; } +#txt2img_skip, #img2img_skip{ + right: 0; + border-radius: 0 0.5rem 0.5rem 0; +} .red { color: red; -- cgit v1.2.3 From 110d1a2d598bcfacffe3d524df1a3422b4cbd8ec Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sun, 15 Jan 2023 12:41:00 -0500 Subject: add fields to settings file --- modules/textual_inversion/logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py index 31e50b64..734a4b6f 100644 --- a/modules/textual_inversion/logging.py +++ b/modules/textual_inversion/logging.py @@ -2,7 +2,7 @@ import datetime import json import os -saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"} +saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "clip_grad_mode", "clip_grad_value", "gradient_step", "data_root", "log_directory", "training_width", "training_height", "steps", 
"create_image_every", "template_file", "gradient_step", "latent_sampling_method"} saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"} saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"} saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet -- cgit v1.2.3 From 598f7fcd84f655dd204ad5e258dc1c41cc806cde Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Mon, 16 Jan 2023 02:46:21 +0900 Subject: Fix loss_dict problem --- modules/hypernetworks/hypernetwork.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index bbd1f673..438e3e9f 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -561,6 +561,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi _loss_step = 0 #internal # size = len(ds.indexes) # loss_dict = defaultdict(lambda : deque(maxlen = 1024)) + loss_logging = [] # losses = torch.zeros((size,)) # previous_mean_losses = [0] # previous_mean_loss = 0 @@ -601,6 +602,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi else: c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory) loss = shared.sd_model(x, c)[0] / gradient_step + loss_logging.append(loss.item()) del x del c @@ -644,7 +646,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi if shared.opts.training_enable_tensorboard: epoch_num = hypernetwork.step // len(ds) epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1 - mean_loss = sum(sum(x) for x in loss_dict.values()) / sum(len(x) for x in loss_dict.values()) + mean_loss = sum(loss_logging) / len(loss_logging) textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num) textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, { -- cgit v1.2.3 From 13445738d974edcca5ff2f4f8f3833c1f3433e5e Mon Sep 17 00:00:00 2001 From: aria1th <35677394+aria1th@users.noreply.github.com> Date: Mon, 16 Jan 2023 03:02:54 +0900 Subject: Fix tensorboard related functions --- modules/hypernetworks/hypernetwork.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 438e3e9f..c963fc40 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -561,7 +561,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi _loss_step = 0 #internal # size = len(ds.indexes) # loss_dict = defaultdict(lambda : deque(maxlen = 1024)) - loss_logging = [] + loss_logging = deque(maxlen=len(ds) * 3) # this should be configurable parameter, this is 3 * epoch(dataset size) # losses = torch.zeros((size,)) # previous_mean_losses = [0] # previous_mean_loss = 0 @@ -602,7 +602,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi else: c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory) loss = shared.sd_model(x, c)[0] / gradient_step - loss_logging.append(loss.item()) del x del c @@ -612,7 +611,7 @@ def train_hypernetwork(id_task, hypernetwork_name, 
learn_rate, batch_size, gradi # go back until we reach gradient accumulation steps if (j + 1) % gradient_step != 0: continue - + loss_logging.append(_loss_step) if clip_grad: clip_grad(weights, clip_grad_sched.learn_rate) @@ -690,9 +689,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi processed = processing.process_images(p) image = processed.images[0] if len(processed.images) > 0 else None - - if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images: - textual_inversion.tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, hypernetwork.step) if unload: shared.sd_model.cond_stage_model.to(devices.cpu) @@ -703,7 +699,10 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi hypernetwork.train() if image is not None: shared.state.assign_current_image(image) - + if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images: + textual_inversion.tensorboard_add_image(tensorboard_writer, + f"Validation at epoch {epoch_num}", image, + hypernetwork.step) last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" -- cgit v1.2.3 From 8e2aeee4a127b295bfc880800e4a312e0f049b85 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 15 Jan 2023 22:29:53 +0300 Subject: add BREAK keyword to end current text chunk and start the next --- modules/prompt_parser.py | 7 ++++++- modules/sd_hijack_clip.py | 17 +++++++++++++---- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 870218db..69665372 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -274,6 +274,7 @@ re_attention = re.compile(r""" : """, re.X) +re_break = re.compile(r"\s*\bBREAK\b\s*", re.S) def parse_prompt_attention(text): """ @@ -339,7 +340,11 @@ def parse_prompt_attention(text): elif text == ']' and len(square_brackets) > 0: multiply_range(square_brackets.pop(), square_bracket_multiplier) else: - res.append([text, 1.0]) + parts = re.split(re_break, text) + for i, part in enumerate(parts): + if i > 0: + res.append(["BREAK", -1]) + res.append([part, 1.0]) for pos in round_brackets: multiply_range(pos, round_bracket_multiplier) diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 852afc66..9fa5c5c5 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -96,13 +96,18 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): token_count = 0 last_comma = -1 - def next_chunk(): - """puts current chunk into the list of results and produces the next one - empty""" + def next_chunk(is_last=False): + """puts current chunk into the list of results and produces the next one - empty; + if is_last is true, tokens tokens at the end won't add to token_count""" nonlocal token_count nonlocal last_comma nonlocal chunk - token_count += len(chunk.tokens) + if is_last: + token_count += len(chunk.tokens) + else: + token_count += self.chunk_length + to_add = self.chunk_length - len(chunk.tokens) if to_add > 0: chunk.tokens += [self.id_end] * to_add @@ -116,6 +121,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): chunk = PromptChunk() for tokens, (text, weight) in zip(tokenized, parsed): + if text == 'BREAK' and weight == -1: + next_chunk() + continue + 
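
# Worked example of the new BREAK handling (illustrative, not part of the
# patch): with the re_break split above, parse_prompt_attention("a cat BREAK a dog")
# yields [['a cat', 1.0], ['BREAK', -1], ['a dog', 1.0]], and the check just
# above closes the current chunk early so "a dog" begins in a fresh chunk.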
position = 0 while position < len(tokens): token = tokens[position] @@ -159,7 +168,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): position += embedding_length_in_tokens if len(chunk.tokens) > 0 or len(chunks) == 0: - next_chunk() + next_chunk(is_last=True) return chunks, token_count -- cgit v1.2.3 From db9b11617997ad02e5eb68be306078b3b8d3e2cf Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Sun, 15 Jan 2023 23:13:58 +0300 Subject: fix paths with parentheses --- webui.bat | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.bat b/webui.bat index e6a7a429..3165b94d 100644 --- a/webui.bat +++ b/webui.bat @@ -1,7 +1,7 @@ @echo off if not defined PYTHON (set PYTHON=python) -if not defined VENV_DIR (set VENV_DIR=%~dp0%venv) +if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv") set ERROR_REPORTING=FALSE -- cgit v1.2.3 From fc25af3939f0366f147892a8ae5b9da56495352b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 15 Jan 2023 23:22:51 +0300 Subject: remove unneeded log from progressbar --- javascript/progressbar.js | 2 -- 1 file changed, 2 deletions(-) diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 59173c83..5072c13f 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -139,8 +139,6 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre var fun = function(id_task, id_live_preview){ request("/internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){ - console.log(res) - if(res.completed){ removeProgressBar() return -- cgit v1.2.3 From 89314e79da21ac71ad3133ccf5ac3e85d4c24052 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 15 Jan 2023 23:23:16 +0300 Subject: fix an error that happens when you send an empty image from txt2img to img2img --- modules/generation_parameters_copypaste.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 593d99ef..a381ff59 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -37,6 +37,9 @@ def quote(text): def image_from_url_text(filedata): + if filedata is None: + return None + if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False): filedata = filedata[0] -- cgit v1.2.3 From 3db22e6ee45193559a2c3ba44ab672b067245f99 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 15 Jan 2023 23:32:38 +0300 Subject: rename masking to inpaint in UI make inpaint go to the right place for users who don't have it in config string --- modules/shared.py | 2 +- modules/ui.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index 3bdc375b..f06ae610 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -116,7 +116,7 @@ restricted_opts = { } ui_reorder_categories = [ - "masking", + "inpaint", "sampler", "dimensions", "cfg", diff --git a/modules/ui.py b/modules/ui.py index b3d4af3e..20b66165 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -570,9 +570,9 @@ def create_sampler_and_steps_selection(choices, tabname): def ordered_ui_categories(): - user_order = {x.strip(): i for i, x in enumerate(shared.opts.ui_reorder.split(","))} + user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder.split(","))} - for i, category in 
sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] + 1000)): + for i, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)): yield category @@ -889,7 +889,7 @@ def create_ui(): with FormGroup(elem_id="img2img_script_container"): custom_inputs = modules.scripts.scripts_img2img.setup_ui() - elif category == "masking": + elif category == "inpaint": with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: with FormRow(): mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id="img2img_mask_blur") -- cgit v1.2.3 From 9a43acf94ead6bc15da2782c39ab5a3107c3f06c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 15 Jan 2023 23:37:34 +0300 Subject: add background color for live previews in dark mode --- style.css | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/style.css b/style.css index 750fe315..78fa9838 100644 --- a/style.css +++ b/style.css @@ -337,6 +337,10 @@ input[type="range"]{ margin: -4px; } +.dark .livePreview{ + background-color: rgb(17 24 39 / var(--tw-bg-opacity)); +} + .livePreview img{ position: absolute; object-fit: contain; -- cgit v1.2.3 From 3f887f7f61d69fa699a272166b79fdb787e9ce1d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 16 Jan 2023 00:44:46 +0300 Subject: support old configs that say "auto" for ssd_vae change sd_vae_as_default to True by default as it's a more sensible setting --- modules/sd_vae.py | 6 ++++-- modules/shared.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index add5cecf..e9c6bb40 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -94,8 +94,10 @@ def resolve_vae(checkpoint_file): if shared.cmd_opts.vae_path is not None: return shared.cmd_opts.vae_path, 'from commandline argument' + is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config + vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) - if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "Automatic"): + if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic): return vae_near_checkpoint, 'found near the checkpoint' if shared.opts.sd_vae == "None": @@ -105,7 +107,7 @@ def resolve_vae(checkpoint_file): if vae_from_options is not None: return vae_from_options, 'specified in settings' - if shared.opts.sd_vae != "Automatic": + if is_automatic: print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead") return None, None diff --git a/modules/shared.py b/modules/shared.py index f06ae610..c5fc250e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -394,7 +394,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list), - "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), + "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to 
them"), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), -- cgit v1.2.3 From ff6a5bcec1ce25aa8f08b157ea957d764be23d8d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 16 Jan 2023 01:28:20 +0300 Subject: bugfix for previous commit --- modules/sd_vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index e9c6bb40..b2af2ce7 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -107,7 +107,7 @@ def resolve_vae(checkpoint_file): if vae_from_options is not None: return vae_from_options, 'specified in settings' - if is_automatic: + if not is_automatic: print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead") return None, None -- cgit v1.2.3 From f202ff1901c27d1f82d5e2684dba9e1ed24ffdf2 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sun, 15 Jan 2023 19:43:34 -0800 Subject: Make XY grid cancellation much faster --- scripts/xy_grid.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index bd3087d4..13a3a046 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -406,6 +406,9 @@ class Script(scripts.Script): grid_infotext = [None] def cell(x, y): + if shared.state.interrupted: + return Processed(p, [], p.seed, "") + pc = copy(p) x_opt.apply(pc, x, xs) y_opt.apply(pc, y, ys) -- cgit v1.2.3 From 029260b4ca7267d7a75319dbc11bca2a8c52774e Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sun, 15 Jan 2023 21:40:57 -0800 Subject: Optimize XY grid to run slower axes fewer times --- scripts/xy_grid.py | 123 ++++++++++++++++++++++++++++++----------------------- 1 file changed, 70 insertions(+), 53 deletions(-) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 13a3a046..074ee919 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -175,76 +175,87 @@ def str_permutations(x): """dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" return x -AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"]) -AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"]) +AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm", "cost"]) +AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm", "cost"]) axis_options = [ - AxisOption("Nothing", str, do_nothing, format_nothing, None), - AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None), - AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None), - AxisOption("Var. 
strength", float, apply_field("subseed_strength"), format_value_add_label, None), - AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None), - AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None), - AxisOption("Prompt S/R", str, apply_prompt, format_value, None), - AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None), - AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers), - AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints), - AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks), - AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None), - AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None), - AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None), - AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None), - AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None), - AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), - AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), - AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), - AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None), - AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None), - AxisOption("VAE", str, apply_vae, format_value_add_label, None), - AxisOption("Styles", str, apply_styles, format_value_add_label, None), + AxisOption("Nothing", str, do_nothing, format_nothing, None, 0), + AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None, 0), + AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None, 0), + AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None, 0), + AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None, 0), + AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None, 0), + AxisOption("Prompt S/R", str, apply_prompt, format_value, None, 0), + AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None, 0), + AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers, 0), + AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints, 1.0), + AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks, 0.2), + AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None, 0), + AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None, 0), + AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None, 0), + AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None, 0), + AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None, 0), + AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None, 0), + AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None, 0), + AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None, 0), + AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None, 0), + AxisOption("Cond. 
Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None, 0), + AxisOption("VAE", str, apply_vae, format_value_add_label, None, 0.7), + AxisOption("Styles", str, apply_styles, format_value_add_label, None, 0), ] -def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images): +def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order): ver_texts = [[images.GridAnnotation(y)] for y in y_labels] hor_texts = [[images.GridAnnotation(x)] for x in x_labels] # Temporary list of all the images that are generated to be populated into the grid. # Will be filled with empty images for any individual step that fails to process properly - image_cache = [] + image_cache = [None] * (len(xs) * len(ys)) processed_result = None cell_mode = "P" - cell_size = (1,1) + cell_size = (1, 1) state.job_count = len(xs) * len(ys) * p.n_iter - for iy, y in enumerate(ys): + def process_cell(x, y, ix, iy): + nonlocal image_cache, processed_result, cell_mode, cell_size + + state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" + + processed: Processed = cell(x, y) + + try: + # this dereference will throw an exception if the image was not processed + # (this happens in cases such as if the user stops the process from the UI) + processed_image = processed.images[0] + + if processed_result is None: + # Use our first valid processed result as a template container to hold our full results + processed_result = copy(processed) + cell_mode = processed_image.mode + cell_size = processed_image.size + processed_result.images = [Image.new(cell_mode, cell_size)] + + image_cache[ix + iy * len(xs)] = processed_image + if include_lone_images: + processed_result.images.append(processed_image) + processed_result.all_prompts.append(processed.prompt) + processed_result.all_seeds.append(processed.seed) + processed_result.infotexts.append(processed.infotexts[0]) + except: + image_cache[ix + iy * len(xs)] = Image.new(cell_mode, cell_size) + + if swap_axes_processing_order: for ix, x in enumerate(xs): - state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" - - processed:Processed = cell(x, y) - try: - # this dereference will throw an exception if the image was not processed - # (this happens in cases such as if the user stops the process from the UI) - processed_image = processed.images[0] - - if processed_result is None: - # Use our first valid processed result as a template container to hold our full results - processed_result = copy(processed) - cell_mode = processed_image.mode - cell_size = processed_image.size - processed_result.images = [Image.new(cell_mode, cell_size)] - - image_cache.append(processed_image) - if include_lone_images: - processed_result.images.append(processed_image) - processed_result.all_prompts.append(processed.prompt) - processed_result.all_seeds.append(processed.seed) - processed_result.infotexts.append(processed.infotexts[0]) - except: - image_cache.append(Image.new(cell_mode, cell_size)) + for iy, y in enumerate(ys): + process_cell(x, y, ix, iy) + else: + for iy, y in enumerate(ys): + for ix, x in enumerate(xs): + process_cell(x, y, ix, iy) if not processed_result: print("Unexpected error: draw_xy_grid failed to return even a single processed image") @@ -405,6 +416,11 @@ class Script(scripts.Script): grid_infotext = [None] + # If one of the axes is very slow to change between (like SD model + # checkpoint), then make sure it is in the outer iteration of the nested + # `for` 
loop. + swap_axes_processing_order = x_opt.cost > y_opt.cost + def cell(x, y): if shared.state.interrupted: return Processed(p, [], p.seed, "") @@ -443,7 +459,8 @@ class Script(scripts.Script): y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], cell=cell, draw_legend=draw_legend, - include_lone_images=include_lone_images + include_lone_images=include_lone_images, + swap_axes_processing_order=swap_axes_processing_order ) if opts.grid_save: -- cgit v1.2.3 From 2144c2eb7f5842caed1227d4ec7e659c79a84ce9 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sun, 15 Jan 2023 21:41:58 -0800 Subject: Add swap axes button for XY Grid --- scripts/xy_grid.py | 26 ++++++++++++++++++++------ style.css | 10 ++++++++++ 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 13a3a046..99a660c1 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -23,6 +23,9 @@ import os import re +up_down_arrow_symbol = "\u2195\ufe0f" + + def apply_field(field): def fun(p, x, xs): setattr(p, field, x) @@ -293,17 +296,28 @@ class Script(scripts.Script): current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img] with gr.Row(): - x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) - x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) - - with gr.Row(): - y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) - y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) + with gr.Column(scale=1, elem_id="xy_grid_button_column"): + swap_axes_button = gr.Button(value=up_down_arrow_symbol, elem_id="xy_grid_swap_axes") + with gr.Column(scale=19): + with gr.Row(): + x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) + x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) + + with gr.Row(): + y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) + y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) + def swap_axes(x_type, x_values, y_type, y_values): + nonlocal current_axis_options + return current_axis_options[y_type].label, y_values, current_axis_options[x_type].label, x_values + + swap_args = [x_type, x_values, y_type, y_values] + swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args) + return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds): diff --git a/style.css b/style.css index 78fa9838..1fddfcc2 100644 --- a/style.css +++ b/style.css @@ -717,6 +717,16 @@ footer { line-height: 2.4em; } +#xy_grid_button_column { + 
min-width: 38px !important; +} + +#xy_grid_button_column button { + height: 100%; + margin-bottom: 0.7em; + margin-left: 1em; +} + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running -- cgit v1.2.3 From 972f5785073b8ba5957add72debd74fc56ee9329 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 16 Jan 2023 09:27:52 +0300 Subject: fix problems related to checkpoint/VAE switching in XY plot --- scripts/xy_grid.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 13a3a046..0cdfa952 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -263,14 +263,12 @@ class SharedSettingsStackHelper(object): def __enter__(self): self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers self.hypernetwork = opts.sd_hypernetwork - self.model = shared.sd_model self.vae = opts.sd_vae def __exit__(self, exc_type, exc_value, tb): - modules.sd_models.reload_model_weights(self.model) - opts.data["sd_vae"] = self.vae - modules.sd_vae.reload_vae_weights(self.model) + modules.sd_models.reload_model_weights() + modules.sd_vae.reload_vae_weights() hypernetwork.load_hypernetwork(self.hypernetwork) hypernetwork.apply_strength() -- cgit v1.2.3 From 064983c0adb00cd9e88d2f06f66c9a1d5bc116c3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 16 Jan 2023 12:56:30 +0300 Subject: return an option to hide progressbar --- javascript/progressbar.js | 1 + modules/shared.py | 1 + 2 files changed, 2 insertions(+) diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 5072c13f..da6709bc 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -121,6 +121,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre var divProgress = document.createElement('div') divProgress.className='progressDiv' + divProgress.style.display = opts.show_progressbar ? "" : "none" var divInner = document.createElement('div') divInner.className='progress' diff --git a/modules/shared.py b/modules/shared.py index c5fc250e..483c4c62 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -451,6 +451,7 @@ options_templates.update(options_section(('ui', "User interface"), { })) options_templates.update(options_section(('ui', "Live previews"), { + "show_progressbar": OptionInfo(True, "Show progressbar"), "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. 
Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), -- cgit v1.2.3 From 55947857f035040d00249f02b17e39370033a99b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 16 Jan 2023 17:36:56 +0300 Subject: add a button for XY Plot to fill in available values for axes that support this --- javascript/hints.js | 1 + scripts/xy_grid.py | 101 ++++++++++++++++++++++++++++++++++------------------ style.css | 12 +------ 3 files changed, 68 insertions(+), 46 deletions(-) diff --git a/javascript/hints.js b/javascript/hints.js index 244bfde2..fa5e5ae8 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -20,6 +20,7 @@ titles = { "\u{1f4be}": "Save style", "\U0001F5D1": "Clear prompt", "\u{1f4cb}": "Apply selected styles to current prompt", + "\u{1f4d2}": "Paste available values into the field", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index e06c11cb..bf4ba92f 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,7 +10,7 @@ import numpy as np import modules.scripts as scripts import gradio as gr -from modules import images, paths, sd_samplers, processing +from modules import images, paths, sd_samplers, processing, sd_models, sd_vae from modules.hypernetworks import hypernetwork from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img from modules.shared import opts, cmd_opts, state @@ -22,8 +22,9 @@ import glob import os import re +from modules.ui_components import ToolButton -up_down_arrow_symbol = "\u2195\ufe0f" +fill_values_symbol = "\U0001f4d2" # 📒 def apply_field(field): @@ -178,34 +179,49 @@ def str_permutations(x): """dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" return x -AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm", "cost"]) -AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm", "cost"]) + +class AxisOption: + def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None): + self.label = label + self.type = type + self.apply = apply + self.format_value = format_value + self.confirm = confirm + self.cost = cost + self.choices = choices + self.is_img2img = False + + +class AxisOptionImg2Img(AxisOption): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_img2img = False axis_options = [ - AxisOption("Nothing", str, do_nothing, format_nothing, None, 0), - AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None, 0), - AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None, 0), - AxisOption("Var. 
strength", float, apply_field("subseed_strength"), format_value_add_label, None, 0), - AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None, 0), - AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None, 0), - AxisOption("Prompt S/R", str, apply_prompt, format_value, None, 0), - AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None, 0), - AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers, 0), - AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints, 1.0), - AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks, 0.2), - AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None, 0), - AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None, 0), - AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None, 0), - AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None, 0), - AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None, 0), - AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None, 0), - AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None, 0), - AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None, 0), - AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None, 0), - AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None, 0), - AxisOption("VAE", str, apply_vae, format_value_add_label, None, 0.7), - AxisOption("Styles", str, apply_styles, format_value_add_label, None, 0), + AxisOption("Nothing", str, do_nothing, format_value=format_nothing), + AxisOption("Seed", int, apply_field("seed")), + AxisOption("Var. seed", int, apply_field("subseed")), + AxisOption("Var. strength", float, apply_field("subseed_strength")), + AxisOption("Steps", int, apply_field("steps")), + AxisOption("CFG Scale", float, apply_field("cfg_scale")), + AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value), + AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), + AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), + AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), + AxisOption("Hypernetwork", str, apply_hypernetwork, format_value=format_value, confirm=confirm_hypernetworks, cost=0.2, choices=lambda: list(shared.hypernetworks)), + AxisOption("Hypernet str.", float, apply_hypernetwork_strength), + AxisOption("Sigma Churn", float, apply_field("s_churn")), + AxisOption("Sigma min", float, apply_field("s_tmin")), + AxisOption("Sigma max", float, apply_field("s_tmax")), + AxisOption("Sigma noise", float, apply_field("s_noise")), + AxisOption("Eta", float, apply_field("eta")), + AxisOption("Clip skip", int, apply_clip_skip), + AxisOption("Denoising", float, apply_field("denoising_strength")), + AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [x.name for x in shared.sd_upscalers]), + AxisOption("Cond. 
Image Mask Weight", float, apply_field("inpainting_mask_weight")), + AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)), + AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), ] @@ -262,7 +278,7 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_ if not processed_result: print("Unexpected error: draw_xy_grid failed to return even a single processed image") - return Processed() + return Processed(p, []) grid = images.image_grid(image_cache, rows=len(ys)) if draw_legend: @@ -302,23 +318,25 @@ class Script(scripts.Script): return "X/Y plot" def ui(self, is_img2img): - current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img] + current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img and is_img2img] with gr.Row(): - with gr.Column(scale=1, elem_id="xy_grid_button_column"): - swap_axes_button = gr.Button(value=up_down_arrow_symbol, elem_id="xy_grid_swap_axes") with gr.Column(scale=19): with gr.Row(): x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) + fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_x_tool_button", visible=False) with gr.Row(): y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) - - draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) - include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) - no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) + fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False) + + with gr.Row(variant="compact"): + draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) + include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) + no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) + swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button") def swap_axes(x_type, x_values, y_type, y_values): nonlocal current_axis_options @@ -327,6 +345,19 @@ class Script(scripts.Script): swap_args = [x_type, x_values, y_type, y_values] swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args) + def fill(x_type): + axis = axis_options[x_type] + return ", ".join(axis.choices()) if axis.choices else gr.update() + + fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values]) + fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values]) + + def select_axis(x_type): + return gr.Button.update(visible=axis_options[x_type].choices is not None) + + x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button]) + y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) + return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] def run(self, p, x_type, x_values, y_type, y_values, 
draw_legend, include_lone_images, no_fixed_seeds): diff --git a/style.css b/style.css index 1fddfcc2..97f9402a 100644 --- a/style.css +++ b/style.css @@ -644,7 +644,7 @@ canvas[key="mask"] { max-width: 2.5em; min-width: 2.5em !important; height: 2.4em; - margin: 0.55em 0; + margin: 0.55em 0.7em 0.55em 0; } #quicksettings .gr-button-tool{ @@ -717,16 +717,6 @@ footer { line-height: 2.4em; } -#xy_grid_button_column { - min-width: 38px !important; -} - -#xy_grid_button_column button { - height: 100%; - margin-bottom: 0.7em; - margin-left: 1em; -} - /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running -- cgit v1.2.3 From 52f6e94338f31c286361802b08ee5210b8244141 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 16 Jan 2023 20:13:23 +0300 Subject: add --skip-install option to prevent running pip in launch.py and speedup launch a bit --- launch.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/launch.py b/launch.py index bcbb792c..715427fd 100644 --- a/launch.py +++ b/launch.py @@ -14,6 +14,7 @@ python = sys.executable git = os.environ.get('GIT', "git") index_url = os.environ.get('INDEX_URL', "") stored_commit_hash = None +skip_install = False def commit_hash(): @@ -89,6 +90,9 @@ def run_python(code, desc=None, errdesc=None): def run_pip(args, desc=None): + if skip_install: + return + index_url_line = f' --index-url {index_url}' if index_url != '' else '' return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}") @@ -173,6 +177,8 @@ def run_extensions_installers(settings_file): def prepare_environment(): + global skip_install + torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") commandline_args = os.environ.get('COMMANDLINE_ARGS', "") @@ -206,6 +212,7 @@ def prepare_environment(): sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers') sys.argv, update_check = extract_arg(sys.argv, '--update-check') sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests') + sys.argv, skip_install = extract_arg(sys.argv, '--skip-install') xformers = '--xformers' in sys.argv ngrok = '--ngrok' in sys.argv -- cgit v1.2.3 From 9991967f40120b88a1dc925fdf7d747d5e016888 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 16 Jan 2023 22:59:46 +0300 Subject: Add a check and explanation for tensor with all NaNs. --- modules/devices.py | 28 ++++++++++++++++++++++++++++ modules/processing.py | 3 +++ modules/sd_samplers.py | 2 ++ 3 files changed, 33 insertions(+) diff --git a/modules/devices.py b/modules/devices.py index caeb0276..6f034948 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -106,6 +106,33 @@ def autocast(disable=False): return torch.autocast("cuda") +class NansException(Exception): + pass + + +def test_for_nans(x, where): + from modules import shared + + if not torch.all(torch.isnan(x)).item(): + return + + if where == "unet": + message = "A tensor with all NaNs was produced in Unet." 
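#   NOTE (illustrative sketch, not part of the patch): the guard at the top
#   of test_for_nans uses torch.all(torch.isnan(x)), so the exception fires
#   only when *every* element is NaN; a tensor that still holds any finite
#   value passes untouched. A minimal check, assuming a torch session:
#
#       x = torch.tensor([float("nan"), 1.0])
#       torch.all(torch.isnan(x)).item()                  # False -> no error
#       torch.all(torch.isnan(x * float("nan"))).item()   # True  -> NansException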
+ + if not shared.cmd_opts.no_half: + message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this." + + elif where == "vae": + message = "A tensor with all NaNs was produced in VAE." + + if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae: + message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this." + else: + message = "A tensor with all NaNs was produced." + + raise NansException(message) + + # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 orig_tensor_to = torch.Tensor.to def tensor_to_fix(self, *args, **kwargs): @@ -156,3 +183,4 @@ if has_mps(): torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) orig_narrow = torch.narrow torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() ) + diff --git a/modules/processing.py b/modules/processing.py index 849f6b19..ab7b3b7d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -608,6 +608,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] + for x in x_samples_ddim: + devices.test_for_nans(x, "vae") + x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 76e0e0d5..6261d1f7 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -351,6 +351,8 @@ class CFGDenoiser(torch.nn.Module): x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) + devices.test_for_nans(x_out, "unet") + if opts.live_preview_content == "Prompt": store_latent(x_out[0:uncond.shape[0]]) elif opts.live_preview_content == "Negative prompt": -- cgit v1.2.3 From e0e80050091ea7f58ae17c69f31d1b5de5e0ae20 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 16 Jan 2023 23:09:08 +0300 Subject: make StableDiffusionProcessing class not hold a reference to shared.sd_model object --- modules/processing.py | 9 +++++---- scripts/xy_grid.py | 1 - 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index ab7b3b7d..9c3673de 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -94,7 +94,7 @@ def txt2img_image_conditioning(sd_model, x, width, height): return image_conditioning -class StableDiffusionProcessing(): +class StableDiffusionProcessing: """ The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing """ @@ -102,7 +102,6 @@ class StableDiffusionProcessing(): if sampler_index is not None: print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) - self.sd_model = sd_model self.outpath_samples: str = outpath_samples self.outpath_grids: str = outpath_grids self.prompt: str = prompt @@ -156,6 +155,10 @@ class 
StableDiffusionProcessing(): self.all_subseeds = None self.iteration = 0 + @property + def sd_model(self): + return shared.sd_model + def txt2img_image_conditioning(self, x, width=None, height=None): self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'} @@ -236,7 +239,6 @@ class StableDiffusionProcessing(): raise NotImplementedError() def close(self): - self.sd_model = None self.sampler = None @@ -471,7 +473,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if k == 'sd_model_checkpoint': sd_models.reload_model_weights() # make onchange call for changing SD model - p.sd_model = shared.sd_model if k == 'sd_vae': sd_vae.reload_vae_weights() # make onchange call for changing VAE diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index bf4ba92f..6629f5d5 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -86,7 +86,6 @@ def apply_checkpoint(p, x, xs): if info is None: raise RuntimeError(f"Unknown checkpoint: {x}") modules.sd_models.reload_model_weights(shared.sd_model, info) - p.sd_model = shared.sd_model def confirm_checkpoints(p, xs): -- cgit v1.2.3 From c091cf1b4acd2047644d3571bcbfd81c81b4c3af Mon Sep 17 00:00:00 2001 From: Nick Petalas Date: Thu, 22 Dec 2022 22:28:10 +0000 Subject: upgrading torch, torchvision, xformers (windows) to use u117 --- launch.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/launch.py b/launch.py index 715427fd..f51f23f7 100644 --- a/launch.py +++ b/launch.py @@ -179,7 +179,7 @@ def run_extensions_installers(settings_file): def prepare_environment(): global skip_install - torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") + torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") commandline_args = os.environ.get('COMMANDLINE_ARGS', "") @@ -187,8 +187,6 @@ def prepare_environment(): clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b") - xformers_windows_package = os.environ.get('XFORMERS_WINDOWS_PACKAGE', 'https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl') - stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') @@ -239,7 +237,7 @@ def prepare_environment(): if (not is_installed("xformers") or reinstall_xformers) and xformers: if platform.system() == "Windows": if platform.python_version().startswith("3.10"): - run_pip(f"install -U -I --no-deps {xformers_windows_package}", "xformers") + run_pip(f"install -U -I --no-deps xformers==0.0.16rc425", "xformers") else: print("Installation of xformers is not supported in this version of Python.") print("You can also check this and build manually: 
https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness") -- cgit v1.2.3 From eb2223340cfdd58efaa157662c279fbf6c90c7d9 Mon Sep 17 00:00:00 2001 From: fuggy <45698918+nonetrix@users.noreply.github.com> Date: Mon, 16 Jan 2023 21:50:30 -0600 Subject: Fix typo --- modules/errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/errors.py b/modules/errors.py index a668c014..a10e8708 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -19,7 +19,7 @@ def display(e: Exception, task): message = str(e) if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message: print_error_explanation(""" -The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file. +The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file. See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this. """) -- cgit v1.2.3 From c361b89026442f3412162657f330d500b803e052 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 17 Jan 2023 11:04:56 +0300 Subject: disable the new NaN check for the CI --- launch.py | 2 ++ modules/devices.py | 3 +++ modules/shared.py | 1 + 3 files changed, 6 insertions(+) diff --git a/launch.py b/launch.py index 715427fd..5afb2956 100644 --- a/launch.py +++ b/launch.py @@ -286,6 +286,8 @@ def tests(test_dir): sys.argv.append("./test/test_files/empty.pt") if "--skip-torch-cuda-test" not in sys.argv: sys.argv.append("--skip-torch-cuda-test") + if "--disable-nan-check" not in sys.argv: + sys.argv.append("--disable-nan-check") print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}") diff --git a/modules/devices.py b/modules/devices.py index 6f034948..206184fb 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -113,6 +113,9 @@ class NansException(Exception): def test_for_nans(x, where): from modules import shared + if shared.cmd_opts.disable_nan_check: + return + if not torch.all(torch.isnan(x)).item(): return diff --git a/modules/shared.py b/modules/shared.py index 483c4c62..a708f23c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -64,6 +64,7 @@ parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. 
By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") +parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) -- cgit v1.2.3 From 4688bfff55dd6607e6608524fb219f97dc6fe8bb Mon Sep 17 00:00:00 2001 From: dan Date: Tue, 17 Jan 2023 17:16:43 +0800 Subject: Add auto-sized cropping UI --- modules/textual_inversion/preprocess.py | 38 ++++++++++++++++++++++++++++++--- modules/ui.py | 28 +++++++++++++++++++++++- 2 files changed, 62 insertions(+), 4 deletions(-) diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 64abff4d..86c1cd33 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -12,7 +12,7 @@ from modules.shared import opts, cmd_opts from modules.textual_inversion import autocrop -def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False): +def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): try: if process_caption: shared.interrogator.load() @@ -20,7 +20,7 @@ def preprocess(id_task, process_src, process_dst, process_width, process_height, if process_caption_deepbooru: deepbooru.model.start() - preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug) + preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio, process_focal_crop, process_focal_crop_face_weight, process_focal_crop_entropy_weight, process_focal_crop_edges_weight, 
process_focal_crop_debug, process_multicrop, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold) finally: @@ -109,8 +109,32 @@ def split_pic(image, inverse_xy, width, height, overlap_ratio): splitted = image.crop((0, y, to_w, y + to_h)) yield splitted +# not using torchvision.transforms.CenterCrop because it doesn't allow float regions +def center_crop(image: Image, w: int, h: int): + iw, ih = image.size + if ih / h < iw / w: + sw = w * ih / h + box = (iw - sw) / 2, 0, iw - (iw - sw) / 2, ih + else: + sh = h * iw / w + box = 0, (ih - sh) / 2, iw, ih - (ih - sh) / 2 + return image.resize((w, h), Image.Resampling.LANCZOS, box) + -def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False): +def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold): + iw, ih = image.size + err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h)) + try: + w, h = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64) + if minarea <= w * h <= maxarea and err(w, h) <= threshold), + key= lambda wh: ((objective=='Maximize area')*wh[0]*wh[1], -err(*wh)) + ) + except ValueError: + return + return center_crop(image, w, h) + + +def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): width = process_width height = process_height src = os.path.abspath(process_src) @@ -194,6 +218,14 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre save_pic(focal, index, params, existing_caption=existing_caption) process_default_resize = False + if process_multicrop: + cropped = multicrop_pic(img, process_multicrop_mindim, process_multicrop_maxdim, process_multicrop_minarea, process_multicrop_maxarea, process_multicrop_objective, process_multicrop_threshold) + if cropped is not None: + save_pic(cropped, index, params, existing_caption=existing_caption) + else: + print(f"skipped {img.width}x{img.height} image {filename} (can't find suitable size within error threshold)") + process_default_resize = False + if process_default_resize: img = images.resize_image(1, img, width, height) save_pic(img, index, params, existing_caption=existing_caption) diff --git a/modules/ui.py b/modules/ui.py index 20b66165..bbce9acd 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1226,6 +1226,7 @@ def create_ui(): process_flip = gr.Checkbox(label='Create flipped copies', elem_id="train_process_flip") process_split = gr.Checkbox(label='Split oversized images', elem_id="train_process_split") process_focal_crop = gr.Checkbox(label='Auto focal point crop', 
elem_id="train_process_focal_crop") + process_multicrop = gr.Checkbox(label='Auto-sized crop', elem_id="train_process_multicrop") process_caption = gr.Checkbox(label='Use BLIP for caption', elem_id="train_process_caption") process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True, elem_id="train_process_caption_deepbooru") @@ -1238,7 +1239,19 @@ def create_ui(): process_focal_crop_entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_entropy_weight") process_focal_crop_edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="train_process_focal_crop_edges_weight") process_focal_crop_debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") - + + with gr.Column(visible=False) as process_multicrop_col: + gr.Markdown('Each image is center-cropped with an automatically chosen width and height.') + with gr.Row(): + process_multicrop_mindim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension lower bound", value=384, elem_id="train_process_multicrop_mindim") + process_multicrop_maxdim = gr.Slider(minimum=64, maximum=2048, step=8, label="Dimension upper bound", value=768, elem_id="train_process_multicrop_maxdim") + with gr.Row(): + process_multicrop_minarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area lower bound", value=64*64, elem_id="train_process_multicrop_minarea") + process_multicrop_maxarea = gr.Slider(minimum=64*64, maximum=2048*2048, step=1, label="Area upper bound", value=640*640, elem_id="train_process_multicrop_maxarea") + with gr.Row(): + process_multicrop_objective = gr.Radio(["Maximize area", "Minimize error"], value="Maximize area", label="Resizing objective", elem_id="train_process_multicrop_objective") + process_multicrop_threshold = gr.Slider(minimum=0, maximum=1, step=0.01, label="Error threshold", value=0.1, elem_id="train_process_multicrop_threshold") + with gr.Row(): with gr.Column(scale=3): gr.HTML(value="") @@ -1260,6 +1273,12 @@ def create_ui(): outputs=[process_focal_crop_row], ) + process_multicrop.change( + fn=lambda show: gr_show(show), + inputs=[process_multicrop], + outputs=[process_multicrop_col], + ) + def get_textual_inversion_template_names(): return sorted([x for x in textual_inversion.textual_inversion_templates]) @@ -1379,6 +1398,13 @@ def create_ui(): process_focal_crop_entropy_weight, process_focal_crop_edges_weight, process_focal_crop_debug, + process_multicrop, + process_multicrop_mindim, + process_multicrop_maxdim, + process_multicrop_minarea, + process_multicrop_maxarea, + process_multicrop_objective, + process_multicrop_threshold, ], outputs=[ ti_output, -- cgit v1.2.3 From aede265f1d6d512ca9e51a305e98a96a215366c4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 17 Jan 2023 13:57:55 +0300 Subject: Fix unable to find Real-ESRGAN model info error (AttributeError: 'NoneType' object has no attribute 'data_path') #6841 #5170 --- modules/realesrgan_model.py | 12 ++++-------- modules/upscaler.py | 1 + 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 3ac0b97a..47f70251 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -38,13 +38,13 @@ class UpscalerRealESRGAN(Upscaler): return img info = self.load_model(path) - if not os.path.exists(info.data_path): + if not 
os.path.exists(info.local_data_path): print("Unable to load RealESRGAN model: %s" % info.name) return img upsampler = RealESRGANer( scale=info.scale, - model_path=info.data_path, + model_path=info.local_data_path, model=info.model(), half=not cmd_opts.no_half, tile=opts.ESRGAN_tile, @@ -58,17 +58,13 @@ class UpscalerRealESRGAN(Upscaler): def load_model(self, path): try: - info = None - for scaler in self.scalers: - if scaler.data_path == path: - info = scaler + info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None) if info is None: print(f"Unable to find model info: {path}") return None - model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True) - info.data_path = model_file + info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True) return info except Exception as e: print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr) diff --git a/modules/upscaler.py b/modules/upscaler.py index 231680cb..a5bf5acb 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -95,6 +95,7 @@ class UpscalerData: def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None): self.name = name self.data_path = path + self.local_data_path = path self.scaler = upscaler self.scale = scale self.model = model -- cgit v1.2.3 From 38b7186e6e3a4dffc93225308b822f0dae43a47d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 17 Jan 2023 14:15:47 +0300 Subject: update sending input event in java script to not cause exception in browser https://github.com/gradio-app/gradio/issues/2981 --- javascript/edit-attention.js | 5 ++--- javascript/extensions.js | 2 +- javascript/ui.js | 8 ++++++++ 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index b947cbec..ccc8344a 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -69,7 +69,6 @@ addEventListener('keydown', (event) => { target.selectionStart = selectionStart; target.selectionEnd = selectionEnd; } - // Since we've modified a Gradio Textbox component manually, we need to simulate an `input` DOM event to ensure its - // internal Svelte data binding remains in sync. - target.dispatchEvent(new Event("input", { bubbles: true })); + + updateInput(target) }); diff --git a/javascript/extensions.js b/javascript/extensions.js index 59179ca6..ac6e35b9 100644 --- a/javascript/extensions.js +++ b/javascript/extensions.js @@ -29,7 +29,7 @@ function install_extension_from_index(button, url){ textarea = gradioApp().querySelector('#extension_to_install textarea') textarea.value = url - textarea.dispatchEvent(new Event("input", { bubbles: true })) + updateInput(textarea) gradioApp().querySelector('#install_extension_button').click() } diff --git a/javascript/ui.js b/javascript/ui.js index ecf97cb3..954beadd 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -278,3 +278,11 @@ function restart_reload(){ return [] } + +// Simulate an `input` DOM event for Gradio Textbox component. Needed after you edit its contents in javascript, otherwise your edits +// will only visible on web page and not sent to python. 
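// Usage sketch (illustrative, not part of the patch; the selector below is
// hypothetical): after editing a Gradio textbox from script, call this so
// the Svelte binding -- and therefore the Python side -- sees the change:
//
//     var box = gradioApp().querySelector('#txt2img_prompt textarea')
//     box.value = 'a new prompt'
//     updateInput(box)
//
// The edit-attention.js and extensions.js hunks above follow this pattern.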
+function updateInput(target){ + let e = new Event("input", { bubbles: true }) + Object.defineProperty(e, "target", {value: target}) + target.dispatchEvent(e); +} -- cgit v1.2.3 From 6e08da2c315c346225aa834017f4e32cfc0de200 Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Tue, 17 Jan 2023 23:50:41 +0900 Subject: Add `--vae-dir` argument --- modules/sd_vae.py | 7 +++++++ modules/shared.py | 1 + 2 files changed, 8 insertions(+) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index b2af2ce7..da1bf15c 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -72,6 +72,13 @@ def refresh_vae_list(): os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'), ] + if shared.cmd_opts.vae_dir is not None and os.path.isdir(shared.cmd_opts.vae_dir): + paths += [ + os.path.join(shared.cmd_opts.vae_dir, '**/*.ckpt'), + os.path.join(shared.cmd_opts.vae_dir, '**/*.pt'), + os.path.join(shared.cmd_opts.vae_dir, '**/*.safetensors'), + ] + candidates = [] for path in paths: candidates += glob.iglob(path, recursive=True) diff --git a/modules/shared.py b/modules/shared.py index a708f23c..a1345ad3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -26,6 +26,7 @@ parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") +parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with stable VAE files") parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") -- cgit v1.2.3 From 5e15a0b422981c0b5484885d0b4d28af6913c76f Mon Sep 17 00:00:00 2001 From: EllangoK Date: Tue, 17 Jan 2023 11:42:44 -0500 Subject: Changed params.txt save to after manual init call --- modules/processing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 9c3673de..4a1f033e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -538,10 +538,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None: p.scripts.process(p) - with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: - processed = Processed(p, [], p.seed, "") - file.write(processed.infotext(p, 0)) - infotexts = [] output_images = [] @@ -572,6 +568,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) + with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: + processed = Processed(p, [], p.seed, "") + file.write(processed.infotext(p, 0)) + if state.job_count == -1: state.job_count = p.n_iter -- cgit v1.2.3 From 3a0d6b77295162146d0a8d04278804334da6f1b4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 17 Jan 2023 23:54:23 +0300 Subject: make it so that PNG images with EXIF do not lose parameters in PNG info tab --- 
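A note on the one-line guard below: read_info_from_image takes geninfo from
the PNG "parameters" text chunk first, and before this change a
present-but-empty EXIF UserComment would overwrite that text with an empty
string. A minimal sketch of the corrected flow -- extract_user_comment() is
a hypothetical stand-in for the EXIF decoding visible in the hunk's context:

    items = image.info or {}                    # PNG text chunks
    geninfo = items.pop('parameters', None)     # generation parameters, if any
    exif_comment = extract_user_comment(image)  # may be "" when absent
    if exif_comment:                            # empty string is falsy: keep geninfo
        items['exif comment'] = exif_comment
        geninfo = exif_comment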
modules/images.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/images.py b/modules/images.py index c3a5fc8b..3b1c5f34 100644 --- a/modules/images.py +++ b/modules/images.py @@ -605,8 +605,9 @@ def read_info_from_image(image): except ValueError: exif_comment = exif_comment.decode('utf8', errors="ignore") - items['exif comment'] = exif_comment - geninfo = exif_comment + if exif_comment: + items['exif comment'] = exif_comment + geninfo = exif_comment for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif', 'loop', 'background', 'timestamp', 'duration']: -- cgit v1.2.3 From d906f87043d809e6d4d8de3c9926e184169b330f Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Wed, 18 Jan 2023 07:52:10 +0900 Subject: fix typo --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index a1345ad3..a42279ec 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -26,7 +26,7 @@ parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") -parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with stable VAE files") +parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files") parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") -- cgit v1.2.3 From a255dac4f8c5ee11c15b634563d3df513f1834b4 Mon Sep 17 00:00:00 2001 From: brkirch Date: Thu, 12 Jan 2023 08:00:38 -0500 Subject: Fix cumsum for MPS in newer torch The prior fix assumed that testing int16 was enough to determine if a fix is needed, but a recent fix for cumsum has int16 working but not bool. 
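In practice that means each dtype has to be probed at runtime instead of inferring all of them from one representative type. A sketch of such a probe against the public torch API (the sample values are arbitrary, and this assumes an MPS-enabled PyTorch build):

    import torch

    def mps_cumsum_broken(values, dtype):
        """True if cumsum over `values` gives a wrong answer on the MPS device."""
        mps = torch.device("mps")
        got = torch.tensor(values, dtype=dtype, device=mps).cumsum(0)
        # The CPU result is the reference; move it over for the comparison.
        want = torch.tensor(values, dtype=dtype).cumsum(0).to(device=mps, dtype=got.dtype)
        return not torch.equal(got, want)

    needs_int_fix = mps_cumsum_broken([1, 1], torch.int16)
    needs_bool_fix = mps_cumsum_broken([True, False], torch.bool)

The fix below keeps the CPU fall-back for int64 and, for the narrower broken dtypes, upcasts the input to int32 on-device and returns the result as int64.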
--- modules/devices.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index caeb0276..ac3ae0c9 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -139,8 +139,10 @@ orig_Tensor_cumsum = torch.Tensor.cumsum def cumsum_fix(input, cumsum_func, *args, **kwargs): if input.device.type == 'mps': output_dtype = kwargs.get('dtype', input.dtype) - if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]): + if output_dtype == torch.int64: return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) + elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16): + return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64) return cumsum_func(input, *args, **kwargs) @@ -151,8 +153,9 @@ if has_mps(): torch.nn.functional.layer_norm = layer_norm_fix torch.Tensor.numpy = numpy_fix elif version.parse(torch.__version__) > version.parse("1.13.1"): - if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)): - torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) ) - torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) + cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0)) + cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0)) + torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) ) + torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) orig_narrow = torch.narrow torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() ) -- cgit v1.2.3 From dac59b9b073f86508d3ec787ff731af2e101fbcc Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 18 Jan 2023 06:13:45 +0300 Subject: return progress percentage to title bar --- javascript/progressbar.js | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/javascript/progressbar.js b/javascript/progressbar.js index da6709bc..b8473ebf 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -106,6 +106,19 @@ function formatTime(secs){ } } +function setTitle(progress){ + var title = 'Stable Diffusion' + + if(opts.show_progress_in_title && progress){ + title = '[' + progress.trim() + '] ' + title; + } + + if(document.title != title){ + document.title = title; + } +} + + function randomId(){ return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")" } @@ -133,6 +146,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre parentGallery.insertBefore(livePreview, gallery) var removeProgressBar = function(){ + setTitle("") parentProgressbar.removeChild(divProgress) parentGallery.removeChild(livePreview) atEnd() @@ -165,6 +179,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre progressText += " " + res.textinfo } + setTitle(progressText) divInner.textContent = progressText var elapsedFromStart = (new Date() - dateStart) / 1000 -- cgit v1.2.3 From 
d8f8bcb821fa62e943eb95ee05b8a949317326fe Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 18 Jan 2023 13:20:47 +0300 Subject: enable progressbar without gallery --- javascript/progressbar.js | 24 +++++++++++++++--------- style.css | 19 +++---------------- 2 files changed, 18 insertions(+), 25 deletions(-) diff --git a/javascript/progressbar.js b/javascript/progressbar.js index b8473ebf..18c771a2 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -130,7 +130,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre var dateStart = new Date() var wasEverActive = false var parentProgressbar = progressbarContainer.parentNode - var parentGallery = gallery.parentNode + var parentGallery = gallery ? gallery.parentNode : null var divProgress = document.createElement('div') divProgress.className='progressDiv' @@ -141,14 +141,16 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre divProgress.appendChild(divInner) parentProgressbar.insertBefore(divProgress, progressbarContainer) - var livePreview = document.createElement('div') - livePreview.className='livePreview' - parentGallery.insertBefore(livePreview, gallery) + if(parentGallery){ + var livePreview = document.createElement('div') + livePreview.className='livePreview' + parentGallery.insertBefore(livePreview, gallery) + } var removeProgressBar = function(){ setTitle("") parentProgressbar.removeChild(divProgress) - parentGallery.removeChild(livePreview) + if(parentGallery) parentGallery.removeChild(livePreview) atEnd() } @@ -168,6 +170,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre progressText = "" divInner.style.width = ((res.progress || 0) * 100.0) + '%' + divInner.style.background = res.progress ? 
"" : "transparent" if(res.progress > 0){ progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%' @@ -175,11 +178,15 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre if(res.eta){ progressText += " ETA: " + formatTime(res.eta) - } else if(res.textinfo){ - progressText += " " + res.textinfo } + setTitle(progressText) + + if(res.textinfo && res.textinfo.indexOf("\n") == -1){ + progressText = res.textinfo + " " + progressText + } + divInner.textContent = progressText var elapsedFromStart = (new Date() - dateStart) / 1000 @@ -197,8 +204,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre } - if(res.live_preview){ - + if(res.live_preview && gallery){ var rect = gallery.getBoundingClientRect() if(rect.width){ livePreview.style.width = rect.width + "px" diff --git a/style.css b/style.css index 97f9402a..b1d47df6 100644 --- a/style.css +++ b/style.css @@ -290,26 +290,12 @@ input[type="range"]{ min-height: unset !important; } -#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{ - position: absolute; - z-index: 1000; - right: 0; - padding-left: 5px; - padding-right: 5px; - display: block; -} - -#txt2img_progress_row, #img2img_progress_row{ - margin-bottom: 10px; - margin-top: -18px; -} - .progressDiv{ position: absolute; height: 20px; top: -20px; background: #b4c0cc; - border-radius: 8px !important; + border-radius: 3px !important; } .dark .progressDiv{ @@ -325,9 +311,10 @@ input[type="range"]{ line-height: 20px; padding: 0 8px 0 0; text-align: right; - border-radius: 8px; + border-radius: 3px; overflow: visible; white-space: nowrap; + padding: 0 0.5em; } .livePreview{ -- cgit v1.2.3 From 0c5913b9c28017523011ac6bf83b38ed5de8c11f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 18 Jan 2023 14:14:50 +0300 Subject: re-enable image dragging on non-firefox browsers --- javascript/imageviewer.js | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 1f29ad7b..aac2ee82 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -148,7 +148,15 @@ function showGalleryImage() { if(e && e.parentElement.tagName == 'DIV'){ e.style.cursor='pointer' e.style.userSelect='none' - e.addEventListener('mousedown', function (evt) { + + var isFirefox = isFirefox = navigator.userAgent.toLowerCase().indexOf('firefox') > -1 + + // For Firefox, listening on click first switched to next image then shows the lightbox. + // If you know how to fix this without switching to mousedown event, please. + // For other browsers the event is click to make it possiblr to drag picture. + var event = isFirefox ? 
'mousedown' : 'click' + + e.addEventListener(event, function (evt) { if(!opts.js_modal_lightbox || evt.button != 0) return; modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed) evt.preventDefault() -- cgit v1.2.3 From 6faae2323963f9b0e0086a85b9d0472a24fbaa73 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 18 Jan 2023 14:33:09 +0300 Subject: repair broken quicksettings when some form-requiring options are added to it --- modules/ui.py | 2 +- style.css | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index e1f98d23..6d70a795 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1659,7 +1659,7 @@ def create_ui(): interfaces += [(extensions_interface, "Extensions", "extensions")] with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: - with gr.Row(elem_id="quicksettings"): + with gr.Row(elem_id="quicksettings", variant="compact"): for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): component = create_setting_component(k, is_quicksettings=True) component_dict[k] = component diff --git a/style.css b/style.css index b6239142..fb58b6c3 100644 --- a/style.css +++ b/style.css @@ -530,7 +530,7 @@ input[type="range"]{ gap: 0.4em; } -#quicksettings > div{ +#quicksettings > div, #quicksettings > fieldset{ max-width: 24em; min-width: 24em; padding: 0; -- cgit v1.2.3 From 05a779b0cd7cfe98d525e70362154a8a4d8b5e09 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Wed, 18 Jan 2023 09:47:38 -0500 Subject: fix syntax error --- javascript/localization.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/localization.js b/javascript/localization.js index bf9e1506..1a5a1dbb 100644 --- a/javascript/localization.js +++ b/javascript/localization.js @@ -11,7 +11,7 @@ ignore_ids_for_localization={ train_embedding: 'OPTION', train_hypernetwork: 'OPTION', txt2img_styles: 'OPTION', - img2img_styles 'OPTION', + img2img_styles: 'OPTION', setting_random_artist_categories: 'SPAN', setting_face_restoration_model: 'SPAN', setting_realesrgan_enabled_models: 'SPAN', -- cgit v1.2.3 From 8683427bd9315d2fda0d2f9644c8b1f6a182da55 Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Wed, 18 Jan 2023 20:16:52 +0300 Subject: Process interrogation on all img2img subtabs --- javascript/ui.js | 7 +++++++ modules/ui.py | 50 +++++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 50 insertions(+), 7 deletions(-) diff --git a/javascript/ui.js b/javascript/ui.js index 954beadd..7d3d57a3 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -109,6 +109,13 @@ function get_extras_tab_index(){ return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args] } +function get_img2img_tab_index() { + let res = args_to_array(arguments) + res.splice(-2) + res[0] = get_tab_index('mode_img2img') + return res +} + function create_submit_args(args){ res = [] for(var i=0;i Date: Wed, 18 Jan 2023 20:29:44 +0300 Subject: make live previews not obscure multiselect dropdowns --- style.css | 1 + 1 file changed, 1 insertion(+) diff --git a/style.css b/style.css index fb58b6c3..61279a19 100644 --- a/style.css +++ b/style.css @@ -148,6 +148,7 @@ #txt2img_styles ul, #img2img_styles ul{ max-height: 35em; + z-index: 2000; } .gr-form{ -- cgit v1.2.3 From 924e222004ab54273806c5f2ca7a0e7cfa76ad83 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 18 Jan 
2023 23:04:24 +0300 Subject: add option to show/hide warnings removed hiding warnings from LDSR fixed/reworked few places that produced warnings --- extensions-builtin/LDSR/ldsr_model_arch.py | 3 -- javascript/localization.js | 2 +- modules/hypernetworks/hypernetwork.py | 7 ++++- modules/sd_hijack.py | 8 ------ modules/sd_hijack_checkpoint.py | 38 +++++++++++++++++++++++++- modules/shared.py | 1 + modules/textual_inversion/textual_inversion.py | 6 +++- modules/ui.py | 31 ++++++++++++--------- scripts/prompts_from_file.py | 2 +- style.css | 5 ++-- 10 files changed, 71 insertions(+), 32 deletions(-) diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index 0ad49f4e..bc11cc6e 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -1,7 +1,6 @@ import os import gc import time -import warnings import numpy as np import torch @@ -15,8 +14,6 @@ from ldm.models.diffusion.ddim import DDIMSampler from ldm.util import instantiate_from_config, ismap from modules import shared, sd_hijack -warnings.filterwarnings("ignore", category=UserWarning) - cached_ldsr_model: torch.nn.Module = None diff --git a/javascript/localization.js b/javascript/localization.js index bf9e1506..1a5a1dbb 100644 --- a/javascript/localization.js +++ b/javascript/localization.js @@ -11,7 +11,7 @@ ignore_ids_for_localization={ train_embedding: 'OPTION', train_hypernetwork: 'OPTION', txt2img_styles: 'OPTION', - img2img_styles 'OPTION', + img2img_styles: 'OPTION', setting_random_artist_categories: 'SPAN', setting_face_restoration_model: 'SPAN', setting_realesrgan_enabled_models: 'SPAN', diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index c963fc40..74e78582 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -12,7 +12,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared, sd_samplers, hashes +from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -575,6 +575,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi pbar = tqdm.tqdm(total=steps - initial_step) try: + sd_hijack_checkpoint.add() + for i in range((steps-initial_step) * gradient_step): if scheduler.finished: break @@ -724,6 +726,9 @@ Last saved image: {html.escape(last_saved_image)}
pbar.close() hypernetwork.eval() #report_statistics(loss_dict) + sd_hijack_checkpoint.remove() + + filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') hypernetwork.optimizer_name = optimizer_name diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 6b0d95af..870eba88 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -69,12 +69,6 @@ def undo_optimizations(): ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward -def fix_checkpoint(): - ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward - ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward - ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward - - class StableDiffusionModelHijack: fixes = None comments = [] @@ -106,8 +100,6 @@ class StableDiffusionModelHijack: self.optimization_method = apply_optimizations() self.clip = m.cond_stage_model - - fix_checkpoint() def flatten(el): flattened = [flatten(children) for children in el.children()] diff --git a/modules/sd_hijack_checkpoint.py b/modules/sd_hijack_checkpoint.py index 5712972f..2604d969 100644 --- a/modules/sd_hijack_checkpoint.py +++ b/modules/sd_hijack_checkpoint.py @@ -1,10 +1,46 @@ from torch.utils.checkpoint import checkpoint +import ldm.modules.attention +import ldm.modules.diffusionmodules.openaimodel + + def BasicTransformerBlock_forward(self, x, context=None): return checkpoint(self._forward, x, context) + def AttentionBlock_forward(self, x): return checkpoint(self._forward, x) + def ResBlock_forward(self, x, emb): - return checkpoint(self._forward, x, emb) \ No newline at end of file + return checkpoint(self._forward, x, emb) + + +stored = [] + + +def add(): + if len(stored) != 0: + return + + stored.extend([ + ldm.modules.attention.BasicTransformerBlock.forward, + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward, + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward + ]) + + ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward + + +def remove(): + if len(stored) == 0: + return + + ldm.modules.attention.BasicTransformerBlock.forward = stored[0] + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1] + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2] + + stored.clear() + diff --git a/modules/shared.py b/modules/shared.py index a708f23c..ddb97f99 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -369,6 +369,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration" })) options_templates.update(options_section(('system', "System"), { + "show_warnings": OptionInfo(False, "Show warnings in console."), "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. 
Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 7e4a6d24..5a7be422 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -15,7 +15,7 @@ import numpy as np from PIL import Image, PngImagePlugin from torch.utils.tensorboard import SummaryWriter -from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers +from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler @@ -452,6 +452,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st pbar = tqdm.tqdm(total=steps - initial_step) try: + sd_hijack_checkpoint.add() + for i in range((steps-initial_step) * gradient_step): if scheduler.finished: break @@ -617,9 +619,11 @@ Last saved image: {html.escape(last_saved_image)}
pbar.close() shared.sd_model.first_stage_model.to(devices.device) shared.parallel_processing_allowed = old_parallel_processing_allowed + sd_hijack_checkpoint.remove() return embedding, filename + def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True): old_embedding_name = embedding.name old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None diff --git a/modules/ui.py b/modules/ui.py index 6d70a795..25818fb0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -11,6 +11,7 @@ import tempfile import time import traceback from functools import partial, reduce +import warnings import gradio as gr import gradio.routes @@ -41,6 +42,8 @@ from modules.textual_inversion import textual_inversion import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text +warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) + # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() mimetypes.add_type('application/javascript', '.js') @@ -417,17 +420,16 @@ def apply_setting(key, value): return value -def update_generation_info(args): - generation_info, html_info, img_index = args +def update_generation_info(generation_info, html_info, img_index): try: generation_info = json.loads(generation_info) if img_index < 0 or img_index >= len(generation_info["infotexts"]): - return html_info - return plaintext_to_html(generation_info["infotexts"][img_index]) + return html_info, gr.update() + return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update() except Exception: pass # if the json parse or anything else fails, just return the old html_info - return html_info + return html_info, gr.update() def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): @@ -508,10 +510,9 @@ Requested path was: {f} generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") generation_info_button.click( fn=update_generation_info, - _js="(x, y) => [x, y, selected_gallery_index()]", - inputs=[generation_info, html_info], - outputs=[html_info], - preprocess=False + _js="function(x, y, z){ console.log(x, y, z); return [x, y, selected_gallery_index()] }", + inputs=[generation_info, html_info, html_info], + outputs=[html_info, html_info], ) save.click( @@ -526,7 +527,8 @@ Requested path was: {f} outputs=[ download_files, html_log, - ] + ], + show_progress=False, ) save_zip.click( @@ -588,7 +590,7 @@ def create_ui(): txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) - txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) + txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) with gr.Row().style(equal_height=False): with gr.Column(variant='compact', elem_id="txt2img_settings"): @@ -768,7 +770,7 @@ def create_ui(): with gr.Blocks(analytics_enabled=False) as img2img_interface: img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = 
create_toprow(is_img2img=True) - img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) + img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) with FormRow().style(equal_height=False): with gr.Column(variant='compact', elem_id="img2img_settings"): @@ -1768,7 +1770,10 @@ def create_ui(): if saved_value is None: ui_settings[key] = getattr(obj, field) elif condition and not condition(saved_value): - print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') + pass + + # this warning is generally not useful; + # print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.') else: setattr(obj, field, saved_value) if init_field is not None: diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index f3e711d7..76dc5778 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -116,7 +116,7 @@ class Script(scripts.Script): checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt")) - file = gr.File(label="Upload prompt inputs", type='bytes', elem_id=self.elem_id("file")) + file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file")) file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt]) diff --git a/style.css b/style.css index 61279a19..0845519a 100644 --- a/style.css +++ b/style.css @@ -299,9 +299,8 @@ input[type="range"]{ } /* more gradio's garbage cleanup */ -.min-h-\[4rem\] { - min-height: unset !important; -} +.min-h-\[4rem\] { min-height: unset !important; } +.min-h-\[6rem\] { min-height: unset !important; } .progressDiv{ position: absolute; -- cgit v1.2.3 From b186d44dcd0df9d127a663b297334a5bd8258b58 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 18 Jan 2023 23:20:23 +0300 Subject: use DDIM in hires fix is the sampler is PLMS --- modules/processing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 9c3673de..8c18ac53 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -857,7 +857,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob() - self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) + img2img_sampler_name = self.sampler_name if self.sampler_name != 'PLMS' else 'DDIM' # PLMS does not support img2img so we just silently switch ot DDIM + self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] -- cgit v1.2.3 From bb0978ecfd3177d0bfd7cacd1ac8796d7eec2d79 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 19 Jan 2023 00:44:51 +0300 Subject: fix hires fix ui weirdness caused by gradio update --- modules/ui.py | 6 +++--- style.css | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 8b7f1dfb..09a3c92e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -638,7 +638,7 @@ def create_ui(): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, 
seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') elif category == "checkboxes": - with FormRow(elem_id="txt2img_checkboxes"): + with FormRow(elem_id="txt2img_checkboxes", variant="compact"): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") @@ -646,12 +646,12 @@ def create_ui(): elif category == "hires_fix": with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: - with FormRow(elem_id="txt2img_hires_fix_row1"): + with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"): hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") - with FormRow(elem_id="txt2img_hires_fix_row2"): + with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"): hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") diff --git a/style.css b/style.css index 0845519a..a6abd93d 100644 --- a/style.css +++ b/style.css @@ -686,7 +686,7 @@ footer { #txt2img_checkboxes, #img2img_checkboxes{ margin-bottom: 0.5em; } -#txt2img_checkboxes > div > div, #img2img_checkboxes > div > div{ +#txt2img_checkboxes > div, #img2img_checkboxes > div{ flex: 0; white-space: nowrap; min-width: auto; -- cgit v1.2.3 From 956263b8a4f0393dcb47ed497f367717add4f0e9 Mon Sep 17 00:00:00 2001 From: facu Date: Wed, 18 Jan 2023 19:15:53 -0300 Subject: fixing error using lspci on macOsX --- webui.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.sh b/webui.sh index 6e07778f..1edf921d 100755 --- a/webui.sh +++ b/webui.sh @@ -165,7 +165,7 @@ else printf "\n%s\n" "${delimiter}" printf "Launching launch.py..." printf "\n%s\n" "${delimiter}" - gpu_info=$(lspci | grep VGA) + gpu_info=$(lspci 2>/dev/null | grep VGA) if echo "$gpu_info" | grep -q "AMD" then if [[ -z "${TORCH_COMMAND}" ]] -- cgit v1.2.3 From 99207bc816d027b522e1c49001748c63fd426b53 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Wed, 18 Jan 2023 19:13:15 -0500 Subject: check model name values are set before merging --- modules/extras.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index 22668fcd..29eb1f07 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -287,10 +287,19 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam def add_difference(theta0, theta1_2_diff, alpha): return theta0 + (alpha * theta1_2_diff) + if not primary_model_name: + shared.state.textinfo = "Failed: Merging requires a primary model." 
+ shared.state.end() + return ["Failed: Merging requires a primary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] + primary_model_info = sd_models.checkpoints_list[primary_model_name] + + if not secondary_model_name: + shared.state.textinfo = "Failed: Merging requires a secondary model." + shared.state.end() + return ["Failed: Merging requires a secondary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] + secondary_model_info = sd_models.checkpoints_list[secondary_model_name] - tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None) - result_is_inpainting_model = False theta_funcs = { "Weighted sum": (None, weighted_sum), @@ -298,10 +307,15 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam } theta_func1, theta_func2 = theta_funcs[interp_method] - if theta_func1 and not tertiary_model_info: + tertiary_model_info = None + if theta_func1 and not tertiary_model_name: shared.state.textinfo = "Failed: Interpolation method requires a tertiary model." shared.state.end() - return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] + return [f"Failed: Interpolation method ({interp_method}) requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] + else: + tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None) + + result_is_inpainting_model = False shared.state.textinfo = f"Loading {secondary_model_info.filename}..." print(f"Loading {secondary_model_info.filename}...") -- cgit v1.2.3 From 26a6a78b16f88a6f88f4cca3f378db3b83fc94f8 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Wed, 18 Jan 2023 21:21:52 -0500 Subject: only lookup tertiary model if theta_func1 is set --- modules/extras.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index 29eb1f07..88eea22e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -307,13 +307,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam } theta_func1, theta_func2 = theta_funcs[interp_method] - tertiary_model_info = None if theta_func1 and not tertiary_model_name: shared.state.textinfo = "Failed: Interpolation method requires a tertiary model." 
shared.state.end() return [f"Failed: Interpolation method ({interp_method}) requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] - else: - tertiary_model_info = sd_models.checkpoints_list.get(tertiary_model_name, None) + + tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None result_is_inpainting_model = False -- cgit v1.2.3 From 308b51012a5def38edb1c2e127e736c43aa6e1a3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 19 Jan 2023 08:41:37 +0300 Subject: fix an unlikely division by 0 error --- modules/progress.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/modules/progress.py b/modules/progress.py index 3327b883..f9e005d3 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -67,10 +67,13 @@ def progressapi(req: ProgressRequest): progress = 0 - if shared.state.job_count > 0: - progress += shared.state.job_no / shared.state.job_count - if shared.state.sampling_steps > 0: - progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps + job_count, job_no = shared.state.job_count, shared.state.job_no + sampling_steps, sampling_step = shared.state.sampling_steps, shared.state.sampling_step + + if job_count > 0: + progress += job_no / job_count + if sampling_steps > 0: + progress += 1 / job_count * sampling_step / sampling_steps progress = min(progress, 1) -- cgit v1.2.3 From 7cfc6450305125683799208fb7bc27c0b12586b3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 19 Jan 2023 08:53:50 +0300 Subject: eliminate repetition of code in #6910 --- modules/extras.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index 88eea22e..367c15cc 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -278,6 +278,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam shared.state.begin() shared.state.job = 'model-merge' + def fail(message): + shared.state.textinfo = message + shared.state.end() + return [message, *[gr.update() for _ in range(4)]] + def weighted_sum(theta0, theta1, alpha): return ((1 - alpha) * theta0) + (alpha * theta1) @@ -288,16 +293,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam return theta0 + (alpha * theta1_2_diff) if not primary_model_name: - shared.state.textinfo = "Failed: Merging requires a primary model." - shared.state.end() - return ["Failed: Merging requires a primary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] + return fail("Failed: Merging requires a primary model.") primary_model_info = sd_models.checkpoints_list[primary_model_name] if not secondary_model_name: - shared.state.textinfo = "Failed: Merging requires a secondary model." - shared.state.end() - return ["Failed: Merging requires a secondary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] + return fail("Failed: Merging requires a secondary model.") secondary_model_info = sd_models.checkpoints_list[secondary_model_name] @@ -308,9 +309,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam theta_func1, theta_func2 = theta_funcs[interp_method] if theta_func1 and not tertiary_model_name: - shared.state.textinfo = "Failed: Interpolation method requires a tertiary model." 
- shared.state.end() - return [f"Failed: Interpolation method ({interp_method}) requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] + return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None -- cgit v1.2.3 From c7e50425f63c07242068f8dcccce70a4ef28a17f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 19 Jan 2023 09:25:37 +0300 Subject: add progress bar to modelmerger --- javascript/ui.js | 11 +++++++++++ modules/extras.py | 18 +++++++++++++++--- modules/progress.py | 2 +- modules/ui.py | 13 ++++++++----- style.css | 5 +++++ 5 files changed, 40 insertions(+), 9 deletions(-) diff --git a/javascript/ui.js b/javascript/ui.js index 7d3d57a3..428375d4 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -172,6 +172,17 @@ function submit_img2img(){ return res } +function modelmerger(){ + var id = randomId() + requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){}) + + gradioApp().getElementById('modelmerger_result').innerHTML = '' + + var res = create_submit_args(arguments) + res[0] = id + return res +} + function ask_for_style_name(_, prompt_text, negative_prompt_text) { name_ = prompt('Style name:') diff --git a/modules/extras.py b/modules/extras.py index 367c15cc..034f28e4 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -274,14 +274,15 @@ def create_config(ckpt_result, config_source, a, b, c): shutil.copyfile(cfg, checkpoint_filename) -def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source): +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source): shared.state.begin() shared.state.job = 'model-merge' + shared.state.job_count = 1 def fail(message): shared.state.textinfo = message shared.state.end() - return [message, *[gr.update() for _ in range(4)]] + return [*[gr.update() for _ in range(4)], message] def weighted_sum(theta0, theta1, alpha): return ((1 - alpha) * theta0) + (alpha * theta1) @@ -320,9 +321,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') if theta_func1: + shared.state.job_count += 1 + print(f"Loading {tertiary_model_info.filename}...") theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') + shared.state.sampling_steps = len(theta_1.keys()) for key in tqdm.tqdm(theta_1.keys()): if 'model' in key: if key in theta_2: @@ -330,8 +334,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam theta_1[key] = theta_func1(theta_1[key], t2) else: theta_1[key] = torch.zeros_like(theta_1[key]) + + shared.state.sampling_step += 1 del theta_2 + shared.state.nextjob() + shared.state.textinfo = f"Loading {primary_model_info.filename}..." 
print(f"Loading {primary_model_info.filename}...") theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') @@ -340,6 +348,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] + shared.state.sampling_steps = len(theta_0.keys()) for key in tqdm.tqdm(theta_0.keys()): if 'model' in key and key in theta_1: @@ -367,6 +376,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam if save_as_half: theta_0[key] = theta_0[key].half() + shared.state.sampling_step += 1 + # I believe this part should be discarded, but I'll leave it for now until I am sure for key in theta_1.keys(): if 'model' in key and key not in theta_0: @@ -393,6 +404,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam output_modelname = os.path.join(ckpt_dir, filename) + shared.state.nextjob() shared.state.textinfo = f"Saving to {output_modelname}..." print(f"Saving to {output_modelname}...") @@ -410,4 +422,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam shared.state.textinfo = "Checkpoint saved to " + output_modelname shared.state.end() - return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)] + return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] diff --git a/modules/progress.py b/modules/progress.py index f9e005d3..c69ecf3d 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -72,7 +72,7 @@ def progressapi(req: ProgressRequest): if job_count > 0: progress += job_no / job_count - if sampling_steps > 0: + if sampling_steps > 0 and job_count > 0: progress += 1 / job_count * sampling_step / sampling_steps progress = min(progress, 1) diff --git a/modules/ui.py b/modules/ui.py index 09a3c92e..aeee7853 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1208,8 +1208,9 @@ def create_ui(): with gr.Row(): modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') - with gr.Column(variant='panel'): - submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) + with gr.Column(variant='compact', elem_id="modelmerger_results_container"): + with gr.Group(elem_id="modelmerger_results_panel"): + modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False) with gr.Blocks(analytics_enabled=False) as train_interface: with gr.Row().style(equal_height=False): @@ -1753,12 +1754,14 @@ def create_ui(): print("Error loading/saving model file:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) modules.sd_models.list_models() # to remove the potentially missing models from the list - return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)] + return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"] return results modelmerger_merge.click( - fn=modelmerger, + fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]), + _js='modelmerger', inputs=[ + dummy_component, primary_model_name, secondary_model_name, tertiary_model_name, @@ -1770,11 +1773,11 @@ def create_ui(): config_source, ], outputs=[ - submit_result, primary_model_name, secondary_model_name, tertiary_model_name, 
component_dict['sd_model_checkpoint'], + modelmerger_result, ] ) diff --git a/style.css b/style.css index a6abd93d..32ba4753 100644 --- a/style.css +++ b/style.css @@ -737,6 +737,11 @@ footer { line-height: 2.4em; } +#modelmerger_results_container{ + margin-top: 1em; + overflow: visible; +} + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running -- cgit v1.2.3 From 0f5dbfffd0b7202a48e404d8e74b5cc9a3e5b135 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 19 Jan 2023 10:39:51 +0300 Subject: allow baking in VAE in checkpoint merger tab do not save config if it's the default for checkpoint merger tab change file naming scheme for checkpoint merger tab allow just saving A without any merging for checkpoint merger tab some stylistic changes for UI in checkpoint merger tab --- javascript/hints.js | 1 + javascript/ui.js | 2 - modules/extras.py | 112 +++++++++++++++++++++++++++++++--------------------- modules/sd_vae.py | 9 ++++- modules/shared.py | 3 +- modules/ui.py | 17 ++++++-- style.css | 15 ++++--- 7 files changed, 101 insertions(+), 58 deletions(-) diff --git a/javascript/hints.js b/javascript/hints.js index fa5e5ae8..e746e20d 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -92,6 +92,7 @@ titles = { "Weighted sum": "Result = A * (1 - M) + B * M", "Add difference": "Result = A + (B - C) * M", + "No interpolation": "Result = A", "Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors", "Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. 
If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.", diff --git a/javascript/ui.js b/javascript/ui.js index 428375d4..37788a3e 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -176,8 +176,6 @@ function modelmerger(){ var id = randomId() requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){}) - gradioApp().getElementById('modelmerger_result').innerHTML = '' - var res = create_submit_args(arguments) res[0] = id return res diff --git a/modules/extras.py b/modules/extras.py index 034f28e4..fe701a0e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -15,7 +15,7 @@ from typing import Callable, List, OrderedDict, Tuple from functools import partial from dataclasses import dataclass -from modules import processing, shared, images, devices, sd_models, sd_samplers +from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae from modules.shared import opts import modules.gfpgan_model from modules.ui import plaintext_to_html @@ -251,7 +251,8 @@ def run_pnginfo(image): def create_config(ckpt_result, config_source, a, b, c): def config(x): - return sd_models.find_checkpoint_config(x) if x else None + res = sd_models.find_checkpoint_config(x) if x else None + return res if res != shared.sd_default_config else None if config_source == 0: cfg = config(a) or config(b) or config(c) @@ -274,10 +275,12 @@ def create_config(ckpt_result, config_source, a, b, c): shutil.copyfile(cfg, checkpoint_filename) -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source): +chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] + + +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae): shared.state.begin() shared.state.job = 'model-merge' - shared.state.job_count = 1 def fail(message): shared.state.textinfo = message @@ -293,41 +296,68 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ def add_difference(theta0, theta1_2_diff, alpha): return theta0 + (alpha * theta1_2_diff) + def filename_weighed_sum(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + Ma = round(1 - multiplier, 2) + Mb = round(multiplier, 2) + + return f"{Ma}({a}) + {Mb}({b})" + + def filename_add_differnece(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + c = tertiary_model_info.model_name + M = round(multiplier, 2) + + return f"{a} + {M}({b} - {c})" + + def filename_nothing(): + return primary_model_info.model_name + + theta_funcs = { + "Weighted sum": (filename_weighed_sum, None, weighted_sum), + "Add difference": (filename_add_differnece, get_difference, add_difference), + "No interpolation": (filename_nothing, None, None), + } + filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] + shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0) + if not primary_model_name: return fail("Failed: Merging requires a primary 
model.") primary_model_info = sd_models.checkpoints_list[primary_model_name] - if not secondary_model_name: + if theta_func2 and not secondary_model_name: return fail("Failed: Merging requires a secondary model.") - - secondary_model_info = sd_models.checkpoints_list[secondary_model_name] - theta_funcs = { - "Weighted sum": (None, weighted_sum), - "Add difference": (get_difference, add_difference), - } - theta_func1, theta_func2 = theta_funcs[interp_method] + secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None if theta_func1 and not tertiary_model_name: return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") - + tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None result_is_inpainting_model = False - shared.state.textinfo = f"Loading {secondary_model_info.filename}..." - print(f"Loading {secondary_model_info.filename}...") - theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') + if theta_func2: + shared.state.textinfo = f"Loading B" + print(f"Loading {secondary_model_info.filename}...") + theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') + else: + theta_1 = None if theta_func1: - shared.state.job_count += 1 - + shared.state.textinfo = f"Loading C" print(f"Loading {tertiary_model_info.filename}...") theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') + shared.state.textinfo = 'Merging B and C' shared.state.sampling_steps = len(theta_1.keys()) for key in tqdm.tqdm(theta_1.keys()): + if key in chckpoint_dict_skip_on_merge: + continue + if 'model' in key: if key in theta_2: t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) @@ -345,12 +375,10 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') print("Merging...") - - chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] - + shared.state.textinfo = 'Merging A and B' shared.state.sampling_steps = len(theta_0.keys()) for key in tqdm.tqdm(theta_0.keys()): - if 'model' in key and key in theta_1: + if theta_1 and 'model' in key and key in theta_1: if key in chckpoint_dict_skip_on_merge: continue @@ -358,7 +386,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ a = theta_0[key] b = theta_1[key] - shared.state.textinfo = f'Merging layer {key}' # this enables merging an inpainting model (A) with another one (B); # where normal model would have 4 channels, for latenst space, inpainting model would # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 @@ -378,34 +405,31 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ shared.state.sampling_step += 1 - # I believe this part should be discarded, but I'll leave it for now until I am sure - for key in theta_1.keys(): - if 'model' in key and key not in theta_0: + del theta_1 + + bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None) + if bake_in_vae_filename is not None: + print(f"Baking in VAE from {bake_in_vae_filename}") + shared.state.textinfo = 'Baking in VAE' + vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') - if key in chckpoint_dict_skip_on_merge: - continue + for key in vae_dict.keys(): + theta_0_key = 'first_stage_model.' 
+ key + if theta_0_key in theta_0: + theta_0[theta_0_key] = vae_dict[key].half() if save_as_half else vae_dict[key] - theta_0[key] = theta_1[key] - if save_as_half: - theta_0[key] = theta_0[key].half() - del theta_1 + del vae_dict ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path - filename = \ - primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \ - secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \ - interp_method.replace(" ", "_") + \ - '-merged.' + \ - ("inpainting." if result_is_inpainting_model else "") + \ - checkpoint_format - - filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format) + filename = filename_generator() if custom_name == '' else custom_name + filename += ".inpainting" if result_is_inpainting_model else "" + filename += "." + checkpoint_format output_modelname = os.path.join(ckpt_dir, filename) shared.state.nextjob() - shared.state.textinfo = f"Saving to {output_modelname}..." + shared.state.textinfo = "Saving" print(f"Saving to {output_modelname}...") _, extension = os.path.splitext(output_modelname) @@ -418,8 +442,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) - print("Checkpoint saved.") - shared.state.textinfo = "Checkpoint saved to " + output_modelname + print(f"Checkpoint saved to {output_modelname}.") + shared.state.textinfo = "Checkpoint saved" shared.state.end() return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] diff --git a/modules/sd_vae.py b/modules/sd_vae.py index da1bf15c..4ce238b8 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -120,6 +120,12 @@ def resolve_vae(checkpoint_file): return None, None +def load_vae_dict(filename, map_location): + vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location) + vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} + return vae_dict_1 + + def load_vae(model, vae_file=None, vae_source="from unknown source"): global vae_dict, loaded_vae_file # save_settings = False @@ -137,8 +143,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"): print(f"Loading VAE weights {vae_source}: {vae_file}") store_base_vae(model) - vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location) - vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} + vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location) _load_vae_dict(model, vae_dict_1) if cache_enabled: diff --git a/modules/shared.py b/modules/shared.py index 77e5e91c..29b28bff 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -20,10 +20,11 @@ from modules.paths import models_path, script_path, sd_path demo = None +sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml") sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file parser = argparse.ArgumentParser() -parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",) +parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion 
model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files") diff --git a/modules/ui.py b/modules/ui.py index aeee7853..4e381a49 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -20,7 +20,7 @@ import numpy as np from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path @@ -1185,7 +1185,7 @@ def create_ui(): with gr.Column(variant='compact'): gr.HTML(value="
<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your checkpoint directory.</p>
") - with FormRow(): + with FormRow(elem_id="modelmerger_models"): primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") @@ -1197,13 +1197,20 @@ def create_ui(): custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") - interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") + interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") with FormRow(): checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") - config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method") + with FormRow(): + with gr.Column(): + config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method") + + with gr.Column(): + with FormRow(): + bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae") + create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae") with gr.Row(): modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') @@ -1757,6 +1764,7 @@ def create_ui(): return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"] return results + modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result]) modelmerger_merge.click( fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]), _js='modelmerger', @@ -1771,6 +1779,7 @@ def create_ui(): custom_name, checkpoint_format, config_source, + bake_in_vae, ], outputs=[ primary_model_name, diff --git a/style.css b/style.css index 32ba4753..c10e32a1 100644 --- a/style.css +++ b/style.css @@ -641,6 +641,16 @@ canvas[key="mask"] { margin: 0.6em 0em 0.55em 0; } +#modelmerger_results_container{ + margin-top: 1em; + overflow: visible; +} + +#modelmerger_models{ + gap: 0; +} + + #quicksettings .gr-button-tool{ margin: 0; } @@ -737,11 +747,6 @@ footer { line-height: 2.4em; } -#modelmerger_results_container{ - margin-top: 1em; - overflow: visible; -} - /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. 
If you change anything above, you need to make sure it is RTL compliant by just running -- cgit v1.2.3 From 54674674b813894b908283531ddaab4ccfeac721 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 19 Jan 2023 12:12:09 +0300 Subject: allow having at half precision when there is only one checkpoint in merger tab --- modules/extras.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index fe701a0e..d03f976e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -278,6 +278,13 @@ def create_config(ckpt_result, config_source, a, b, c): chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] +def to_half(tensor, enable): + if enable and tensor.dtype == torch.float: + return tensor.half() + + return tensor + + def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae): shared.state.begin() shared.state.job = 'model-merge' @@ -400,8 +407,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ else: theta_0[key] = theta_func2(a, b, multiplier) - if save_as_half: - theta_0[key] = theta_0[key].half() + theta_0[key] = to_half(theta_0[key], save_as_half) shared.state.sampling_step += 1 @@ -416,10 +422,14 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ for key in vae_dict.keys(): theta_0_key = 'first_stage_model.' + key if theta_0_key in theta_0: - theta_0[theta_0_key] = vae_dict[key].half() if save_as_half else vae_dict[key] + theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half) del vae_dict + if save_as_half and not theta_func2: + for key in theta_0.keys(): + theta_0[key] = to_half(theta_0[key], save_as_half) + ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path filename = filename_generator() if custom_name == '' else custom_name -- cgit v1.2.3 From 18a09c7e0032e2e655269e8e2b4f1ca6ed0cc7d3 Mon Sep 17 00:00:00 2001 From: dan Date: Thu, 19 Jan 2023 17:36:23 +0800 Subject: Simplification and bugfix --- modules/textual_inversion/preprocess.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 86c1cd33..454dcc36 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -124,13 +124,11 @@ def center_crop(image: Image, w: int, h: int): def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold): iw, ih = image.size err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h)) - try: - w, h = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64) - if minarea <= w * h <= maxarea and err(w, h) <= threshold), - key= lambda wh: ((objective=='Maximize area')*wh[0]*wh[1], -err(*wh)) - ) - except ValueError: - return + w, h = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64) + if minarea <= w * h <= maxarea and err(w, h) <= threshold), + key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1], + default=None + ) return center_crop(image, w, h) -- cgit v1.2.3 From 2985b317d719f0f0580d2ff93f3008ccabb9c251 Mon Sep 17 00:00:00 2001 From: dan Date: Thu, 19 Jan 2023 17:39:30 +0800 Subject: Fix of fix --- modules/textual_inversion/preprocess.py | 4 ++-- 1 file changed, 2 insertions(+), 
2 deletions(-) diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 454dcc36..c0ac11d3 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -124,12 +124,12 @@ def center_crop(image: Image, w: int, h: int): def multicrop_pic(image: Image, mindim, maxdim, minarea, maxarea, objective, threshold): iw, ih = image.size err = lambda w, h: 1-(lambda x: x if x < 1 else 1/x)(iw/ih/(w/h)) - w, h = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64) + wh = max(((w, h) for w in range(mindim, maxdim+1, 64) for h in range(mindim, maxdim+1, 64) if minarea <= w * h <= maxarea and err(w, h) <= threshold), key= lambda wh: (wh[0]*wh[1], -err(*wh))[::1 if objective=='Maximize area' else -1], default=None ) - return center_crop(image, w, h) + return wh and center_crop(image, *wh) def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False, process_multicrop=None, process_multicrop_mindim=None, process_multicrop_maxdim=None, process_multicrop_minarea=None, process_multicrop_maxarea=None, process_multicrop_objective=None, process_multicrop_threshold=None): -- cgit v1.2.3 From b271e22f7ac1b2cabca8985b1e4437ab685a2c21 Mon Sep 17 00:00:00 2001 From: vt-idiot <81622808+vt-idiot@users.noreply.github.com> Date: Thu, 19 Jan 2023 06:12:19 -0500 Subject: Update shared.py `Witdth/Height` was driving me insane. -> `Width/Height` --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/shared.py b/modules/shared.py index 29b28bff..2f366454 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -448,7 +448,7 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"), - "dimensions_and_batch_together": OptionInfo(True, "Show Witdth/Height and Batch sliders in same row"), + "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), -- cgit v1.2.3 From c12d7ddd725c485682c1caa025627c9ee936d743 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 19 Jan 2023 15:58:32 +0300 Subject: add handling to some places in javascript that can potentially cause issues #6898 --- .../javascript/prompt-bracket-checker.js | 10 ++++++---- javascript/progressbar.js | 9 +++++++-- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js 
b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js index eccfb0f9..251a1f57 100644 --- a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js +++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js @@ -93,10 +93,12 @@ function checkBrackets(evt) { } var shadowRootLoaded = setInterval(function() { - var shadowTextArea = document.querySelector('gradio-app').shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea'); - if(shadowTextArea.length < 1) { - return false; - } + var sahdowRoot = document.querySelector('gradio-app').shadowRoot; + if(! sahdowRoot) return false; + + var shadowTextArea = sahdowRoot.querySelectorAll('#txt2img_prompt > label > textarea'); + if(shadowTextArea.length < 1) return false; + clearInterval(shadowRootLoaded); diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 18c771a2..2514d2e2 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -81,8 +81,13 @@ function request(url, data, handler, errorHandler){ xhr.onreadystatechange = function () { if (xhr.readyState === 4) { if (xhr.status === 200) { - var js = JSON.parse(xhr.responseText); - handler(js) + try { + var js = JSON.parse(xhr.responseText); + handler(js) + } catch (error) { + console.error(error); + errorHandler() + } } else{ errorHandler() } -- cgit v1.2.3 From 81276cde90ebecfab317cc62a0100d298c3c43c4 Mon Sep 17 00:00:00 2001 From: poiuty Date: Thu, 19 Jan 2023 16:56:45 +0300 Subject: internal progress relative path --- javascript/progressbar.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 2514d2e2..ff6d757b 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -160,7 +160,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre } var fun = function(id_task, id_live_preview){ - request("/internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){ + request("./internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){ if(res.completed){ removeProgressBar() return -- cgit v1.2.3 From d1ea518dea3d7584be2927cc486d15ec3e18ddb0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 19 Jan 2023 18:07:37 +0300 Subject: remember the list of checkpoints after you press refresh button and reload the page --- modules/ui.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index af416d5f..0c5ba358 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1771,8 +1771,17 @@ def create_ui(): component_keys = [k for k in opts.data_labels.keys() if k in component_dict] + def get_value_for_setting(key): + value = getattr(opts, key) + + info = opts.data_labels[key] + args = info.component_args() if callable(info.component_args) else info.component_args or {} + args = {k: v for k, v in args.items() if k not in {'precision'}} + + return gr.update(value=value, **args) + def get_settings_values(): - return [getattr(opts, key) for key in component_keys] + return [get_value_for_setting(key) for key in component_keys] demo.load( fn=get_settings_values, -- cgit v1.2.3 From f2ae2529877072874ebaac0257fe4af48c5855a4 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Thu, 19 Jan 2023 10:24:17 -0500 Subject: fixes minor typos around run_modelmerger --- modules/extras.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git 
a/modules/extras.py b/modules/extras.py index d03f976e..1218f88f 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -275,7 +275,7 @@ def create_config(ckpt_result, config_source, a, b, c): shutil.copyfile(cfg, checkpoint_filename) -chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] +checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] def to_half(tensor, enable): @@ -303,7 +303,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ def add_difference(theta0, theta1_2_diff, alpha): return theta0 + (alpha * theta1_2_diff) - def filename_weighed_sum(): + def filename_weighted_sum(): a = primary_model_info.model_name b = secondary_model_info.model_name Ma = round(1 - multiplier, 2) @@ -311,7 +311,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ return f"{Ma}({a}) + {Mb}({b})" - def filename_add_differnece(): + def filename_add_difference(): a = primary_model_info.model_name b = secondary_model_info.model_name c = tertiary_model_info.model_name @@ -323,8 +323,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ return primary_model_info.model_name theta_funcs = { - "Weighted sum": (filename_weighed_sum, None, weighted_sum), - "Add difference": (filename_add_differnece, get_difference, add_difference), + "Weighted sum": (filename_weighted_sum, None, weighted_sum), + "Add difference": (filename_add_difference, get_difference, add_difference), "No interpolation": (filename_nothing, None, None), } filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] @@ -362,7 +362,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ shared.state.textinfo = 'Merging B and C' shared.state.sampling_steps = len(theta_1.keys()) for key in tqdm.tqdm(theta_1.keys()): - if key in chckpoint_dict_skip_on_merge: + if key in checkpoint_dict_skip_on_merge: continue if 'model' in key: @@ -387,7 +387,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ for key in tqdm.tqdm(theta_0.keys()): if theta_1 and 'model' in key and key in theta_1: - if key in chckpoint_dict_skip_on_merge: + if key in checkpoint_dict_skip_on_merge: continue a = theta_0[key] -- cgit v1.2.3 From c1928cdd6194928af0f53f70c51d59479b7025e2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 19 Jan 2023 18:58:08 +0300 Subject: bring back short hashes to sd checkpoint selection --- modules/sd_models.py | 15 +++++++++++---- modules/ui.py | 23 ++++++++++++----------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/modules/sd_models.py b/modules/sd_models.py index 6a681cef..12083848 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -41,14 +41,16 @@ class CheckpointInfo: if name.startswith("\\") or name.startswith("/"): name = name[1:] - self.title = name + self.name = name self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] self.hash = model_hash(filename) - self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title) + self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name) self.shorthash = self.sha256[0:10] if self.sha256 else None - self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else []) + self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' + + self.ids = 
[self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) def register(self): checkpoints_list[self.title] = self @@ -56,13 +58,15 @@ class CheckpointInfo: checkpoint_alisases[id] = self def calculate_shorthash(self): - self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title) + self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name) self.shorthash = self.sha256[0:10] if self.shorthash not in self.ids: self.ids += [self.shorthash, self.sha256] self.register() + self.title = f'{self.name} [{self.shorthash}]' + return self.shorthash @@ -225,7 +229,10 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None def load_model_weights(model, checkpoint_info: CheckpointInfo): + title = checkpoint_info.title sd_model_hash = checkpoint_info.calculate_shorthash() + if checkpoint_info.title != title: + shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title cache_enabled = shared.opts.sd_checkpoint_cache > 0 diff --git a/modules/ui.py b/modules/ui.py index 0c5ba358..13d80ae2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -439,7 +439,7 @@ def apply_setting(key, value): opts.data_labels[key].onchange() opts.save(shared.config_filename) - return value + return getattr(opts, key) def update_generation_info(generation_info, html_info, img_index): @@ -597,6 +597,16 @@ def ordered_ui_categories(): yield category +def get_value_for_setting(key): + value = getattr(opts, key) + + info = opts.data_labels[key] + args = info.component_args() if callable(info.component_args) else info.component_args or {} + args = {k: v for k, v in args.items() if k not in {'precision'}} + + return gr.update(value=value, **args) + + def create_ui(): import modules.img2img import modules.txt2img @@ -1600,7 +1610,7 @@ def create_ui(): opts.save(shared.config_filename) - return gr.update(value=value), opts.dumpjson() + return get_value_for_setting(key), opts.dumpjson() with gr.Blocks(analytics_enabled=False) as settings_interface: with gr.Row(): @@ -1771,15 +1781,6 @@ def create_ui(): component_keys = [k for k in opts.data_labels.keys() if k in component_dict] - def get_value_for_setting(key): - value = getattr(opts, key) - - info = opts.data_labels[key] - args = info.component_args() if callable(info.component_args) else info.component_args or {} - args = {k: v for k, v in args.items() if k not in {'precision'}} - - return gr.update(value=value, **args) - def get_settings_values(): return [get_value_for_setting(key) for key in component_keys] -- cgit v1.2.3 From 4599e8ad0acaae3f13bd3a7bef4db7632aac8504 Mon Sep 17 00:00:00 2001 From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Date: Thu, 19 Jan 2023 17:00:51 +0100 Subject: Environment variable on launch just for Navi cards Setting HSA_OVERRIDE_GFX_VERSION=10.3.0 for all AMD cards seems to break compatibility for polaris and vega cards so it should just be enabled on Navi --- webui.sh | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/webui.sh b/webui.sh index 1edf921d..a35a5f35 100755 --- a/webui.sh +++ b/webui.sh @@ -172,7 +172,12 @@ else then export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2" fi - HSA_OVERRIDE_GFX_VERSION=10.3.0 exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" + if echo "$gpu_info" | grep -q "Navi" + then + HSA_OVERRIDE_GFX_VERSION=10.3.0 exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" + else + exec 
"${python_cmd}" "${LAUNCH_SCRIPT}" "$@" + fi else exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" fi -- cgit v1.2.3 From 6073456c8348d15716b9bc5276d994fe8554e4ca Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 19 Jan 2023 20:39:03 +0300 Subject: write a comment for fix_checkpoint function --- modules/sd_hijack.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 870eba88..f9652d21 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -69,6 +69,13 @@ def undo_optimizations(): ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward +def fix_checkpoint(): + """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want + checkpoints to be added when not training (there's a warning)""" + + pass + + class StableDiffusionModelHijack: fixes = None comments = [] -- cgit v1.2.3 From c09fb3d8f1f71bc66d7c4cea603885619d6a1cd4 Mon Sep 17 00:00:00 2001 From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Date: Thu, 19 Jan 2023 19:21:02 +0100 Subject: Simplify GPU check --- webui.sh | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/webui.sh b/webui.sh index a35a5f35..aa4f875c 100755 --- a/webui.sh +++ b/webui.sh @@ -104,6 +104,12 @@ then fi # Check prerequisites +gpu_info=$(lspci 2>/dev/null | grep VGA) +if echo "$gpu_info" | grep -q "Navi" +then + export HSA_OVERRIDE_GFX_VERSION=10.3.0 +fi + for preq in "${GIT}" "${python_cmd}" do if ! hash "${preq}" &>/dev/null @@ -165,20 +171,9 @@ else printf "\n%s\n" "${delimiter}" printf "Launching launch.py..." printf "\n%s\n" "${delimiter}" - gpu_info=$(lspci 2>/dev/null | grep VGA) - if echo "$gpu_info" | grep -q "AMD" + if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]] then - if [[ -z "${TORCH_COMMAND}" ]] - then - export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2" - fi - if echo "$gpu_info" | grep -q "Navi" - then - HSA_OVERRIDE_GFX_VERSION=10.3.0 exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" - else - exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" - fi - else - exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" - fi + export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2" + fi + exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" fi -- cgit v1.2.3 From 48045545d9a3f174621a62086812d9bbfb3ce1c2 Mon Sep 17 00:00:00 2001 From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Date: Thu, 19 Jan 2023 19:23:40 +0100 Subject: Small reformat of the GPU check --- webui.sh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/webui.sh b/webui.sh index aa4f875c..ff410e15 100755 --- a/webui.sh +++ b/webui.sh @@ -109,6 +109,10 @@ if echo "$gpu_info" | grep -q "Navi" then export HSA_OVERRIDE_GFX_VERSION=10.3.0 fi +if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]] +then + export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2" +fi for preq in "${GIT}" "${python_cmd}" do @@ -170,10 +174,6 @@ then else printf "\n%s\n" "${delimiter}" printf "Launching launch.py..." 
- printf "\n%s\n" "${delimiter}" - if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]] - then - export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2" - fi + printf "\n%s\n" "${delimiter}" exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" fi -- cgit v1.2.3 From 36364bd76c4634820e08070a287f0a5ad27c35f6 Mon Sep 17 00:00:00 2001 From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Date: Thu, 19 Jan 2023 20:05:49 +0100 Subject: GFX env just for RDNA 1 and 2 This commit specifies which GPUs should use the GFX variable, RDNA 3 is excluded since it uses a newer GFX version --- webui.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.sh b/webui.sh index ff410e15..91c95e47 100755 --- a/webui.sh +++ b/webui.sh @@ -105,7 +105,7 @@ fi # Check prerequisites gpu_info=$(lspci 2>/dev/null | grep VGA) -if echo "$gpu_info" | grep -q "Navi" +if echo "$gpu_info" | grep -qE "Navi (1|2)" then export HSA_OVERRIDE_GFX_VERSION=10.3.0 fi -- cgit v1.2.3 From 912285ae64e4e1186feb54caf82b4a0b11c6cb7f Mon Sep 17 00:00:00 2001 From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Date: Thu, 19 Jan 2023 23:42:12 +0100 Subject: Experimental support for Renoir This adds the GFX version 9.0.0 in order to use Renoir GPUs with at least 4 GB of VRAM (it's possible to increase the virtual VRAM from the BIOS settings of some vendors). This will only work if the remaining ram is at least 12 GB to avoid the system to become unresponsive on launch.). This change also changes the GPU check to a case statement to be able to add more GPUs efficiently. --- webui.sh | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/webui.sh b/webui.sh index 91c95e47..27933c04 100755 --- a/webui.sh +++ b/webui.sh @@ -105,10 +105,14 @@ fi # Check prerequisites gpu_info=$(lspci 2>/dev/null | grep VGA) -if echo "$gpu_info" | grep -qE "Navi (1|2)" -then - export HSA_OVERRIDE_GFX_VERSION=10.3.0 -fi +case "$gpu_info" in + *"Navi 1"*|*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0 + ;; + *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0 + ;; + *) + ;; +esac if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]] then export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2" -- cgit v1.2.3 From 0684a6819dfaec40732271ca5ef32392c36f17ba Mon Sep 17 00:00:00 2001 From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Date: Fri, 20 Jan 2023 00:21:05 +0100 Subject: Usage explanation for Renoir users --- webui.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/webui.sh b/webui.sh index 27933c04..4da51880 100755 --- a/webui.sh +++ b/webui.sh @@ -109,6 +109,9 @@ case "$gpu_info" in *"Navi 1"*|*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0 ;; *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0 + printf "\n%s\n" "${delimiter}" + printf "Make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half" + printf "\n%s\n" "${delimiter}" ;; *) ;; -- cgit v1.2.3 From fd651bd0bcceb4c746c86b202702bca029cbd6db Mon Sep 17 00:00:00 2001 From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Date: Fri, 20 Jan 2023 00:21:51 +0100 Subject: Update webui.sh --- webui.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/webui.sh b/webui.sh index 4da51880..d5e7b3c5 100755 --- a/webui.sh +++ b/webui.sh @@ -110,8 +110,8 @@ case "$gpu_info" in ;; *"Renoir"*) export 
HSA_OVERRIDE_GFX_VERSION=9.0.0 printf "\n%s\n" "${delimiter}" - printf "Make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half" - printf "\n%s\n" "${delimiter}" + printf "Make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half" + printf "\n%s\n" "${delimiter}" ;; *) ;; -- cgit v1.2.3 From 6c7a50d783c4e406d8597f9cf354bb8128026f6c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 20 Jan 2023 08:36:30 +0300 Subject: remove some unnecessary logging to javascript console --- javascript/hires_fix.js | 3 --- modules/ui.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/javascript/hires_fix.js b/javascript/hires_fix.js index 07fba549..0629475f 100644 --- a/javascript/hires_fix.js +++ b/javascript/hires_fix.js @@ -1,6 +1,5 @@ function setInactive(elem, inactive){ - console.log(elem) if(inactive){ elem.classList.add('inactive') } else{ @@ -9,8 +8,6 @@ function setInactive(elem, inactive){ } function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){ - console.log(enable, width, height, hr_scale, hr_resize_x, hr_resize_y) - hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale') hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x') hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y') diff --git a/modules/ui.py b/modules/ui.py index 13d80ae2..eb45a128 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -532,7 +532,7 @@ Requested path was: {f} generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") generation_info_button.click( fn=update_generation_info, - _js="function(x, y, z){ console.log(x, y, z); return [x, y, selected_gallery_index()] }", + _js="function(x, y, z){ return [x, y, selected_gallery_index()] }", inputs=[generation_info, html_info, html_info], outputs=[html_info, html_info], ) -- cgit v1.2.3 From 98466da4bc312c0fa9c8cea4c825afc64194cb58 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Fri, 20 Jan 2023 00:48:15 -0500 Subject: adds descriptions for merging methods in ui --- modules/ui.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index eb45a128..ee434bde 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1190,10 +1190,19 @@ def create_ui(): outputs=[html, generation_info, html2], ) + def update_interp_description(value): + interp_description_css = "
<p style='margin-bottom: 2.5em'>{}</p>
" + interp_descriptions = { + "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."), + "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"), + "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M") + } + return interp_descriptions[value] + with gr.Blocks(analytics_enabled=False) as modelmerger_interface: with gr.Row().style(equal_height=False): with gr.Column(variant='compact'): - gr.HTML(value="
<p style='margin-bottom: 2.5em'>A merger of the two checkpoints will be generated in your checkpoint directory.</p>
") + interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description") with FormRow(elem_id="modelmerger_models"): primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") @@ -1208,6 +1217,7 @@ def create_ui(): custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") + interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description]) with FormRow(): checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") @@ -1903,6 +1913,9 @@ def create_ui(): with open(ui_config_file, "w", encoding="utf8") as file: json.dump(ui_settings, file, indent=4) + # Required as a workaround for change() event not triggering when loading values from ui-config.json + interp_description.value = update_interp_description(interp_method.value) + return demo -- cgit v1.2.3 From 20a59ab3b171f398abd09087108c1ed087dbea9b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 20 Jan 2023 10:18:41 +0300 Subject: move token counter to the location of the prompt, add token counting for the negative prompt --- .../javascript/prompt-bracket-checker.js | 45 +++++++++++---------- javascript/ui.js | 29 +++++++++---- modules/ui.py | 25 ++++++------ style.css | 47 ++++++++++++++-------- 4 files changed, 87 insertions(+), 59 deletions(-) diff --git a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js index 251a1f57..4a85c8eb 100644 --- a/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js +++ b/extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js @@ -4,16 +4,10 @@ // Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs. // If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong. -function checkBrackets(evt) { - textArea = evt.target; - tabName = evt.target.parentElement.parentElement.id.split("_")[0]; - counterElt = document.querySelector('gradio-app').shadowRoot.querySelector('#' + tabName + '_token_counter'); - - promptName = evt.target.parentElement.parentElement.id.includes('neg') ? ' negative' : ''; - - errorStringParen = '(' + tabName + promptName + ' prompt) - Different number of opening and closing parentheses detected.\n'; - errorStringSquare = '[' + tabName + promptName + ' prompt] - Different number of opening and closing square brackets detected.\n'; - errorStringCurly = '{' + tabName + promptName + ' prompt} - Different number of opening and closing curly brackets detected.\n'; +function checkBrackets(evt, textArea, counterElt) { + errorStringParen = '(...) - Different number of opening and closing parentheses detected.\n'; + errorStringSquare = '[...] 
- Different number of opening and closing square brackets detected.\n'; + errorStringCurly = '{...} - Different number of opening and closing curly brackets detected.\n'; openBracketRegExp = /\(/g; closeBracketRegExp = /\)/g; @@ -86,24 +80,31 @@ function checkBrackets(evt) { } if(counterElt.title != '') { - counterElt.style = 'color: #FF5555;'; + counterElt.classList.add('error'); } else { - counterElt.style = ''; + counterElt.classList.remove('error'); } } -var shadowRootLoaded = setInterval(function() { - var sahdowRoot = document.querySelector('gradio-app').shadowRoot; - if(! sahdowRoot) return false; +function setupBracketChecking(id_prompt, id_counter){ + var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea"); + var counter = gradioApp().getElementById(id_counter) + textarea.addEventListener("input", function(evt){ + checkBrackets(evt, textarea, counter) + }); +} - var shadowTextArea = sahdowRoot.querySelectorAll('#txt2img_prompt > label > textarea'); - if(shadowTextArea.length < 1) return false; +var shadowRootLoaded = setInterval(function() { + var shadowRoot = document.querySelector('gradio-app').shadowRoot; + if(! shadowRoot) return false; + var shadowTextArea = shadowRoot.querySelectorAll('#txt2img_prompt > label > textarea'); + if(shadowTextArea.length < 1) return false; - clearInterval(shadowRootLoaded); + clearInterval(shadowRootLoaded); - document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_prompt').onkeyup = checkBrackets; - document.querySelector('gradio-app').shadowRoot.querySelector('#txt2img_neg_prompt').onkeyup = checkBrackets; - document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_prompt').onkeyup = checkBrackets; - document.querySelector('gradio-app').shadowRoot.querySelector('#img2img_neg_prompt').onkeyup = checkBrackets; + setupBracketChecking('txt2img_prompt', 'txt2img_token_counter') + setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter') + setupBracketChecking('img2img_prompt', 'imgimg_token_counter') + setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter') }, 1000); diff --git a/javascript/ui.js b/javascript/ui.js index 37788a3e..3ba90ca8 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -230,14 +230,26 @@ onUiUpdate(function(){ json_elem.parentElement.style.display="none" - if (!txt2img_textarea) { - txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea"); - txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button")); - } - if (!img2img_textarea) { - img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea"); - img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button")); - } + function registerTextarea(id, id_counter, id_button){ + var prompt = gradioApp().getElementById(id) + var counter = gradioApp().getElementById(id_counter) + var textarea = gradioApp().querySelector("#" + id + " > label > textarea"); + + if(counter.parentElement == prompt.parentElement){ + return + } + + prompt.parentElement.insertBefore(counter, prompt) + counter.classList.add("token-counter") + prompt.parentElement.style.position = "relative" + + textarea.addEventListener("input", () => update_token_counter(id_button)); + } + + registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button') + registerTextarea('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button') + registerTextarea('img2img_prompt', 
'img2img_token_counter', 'img2img_token_button')
+    registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button')
 
     show_all_pages = gradioApp().getElementById('settings_show_all_pages')
     settings_tabs = gradioApp().querySelector('#settings div')
@@ -249,6 +261,7 @@ onUiUpdate(function(){
         })
     }
 }
+
 })
diff --git a/modules/ui.py b/modules/ui.py
index eb45a128..06c11848 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -335,28 +335,23 @@ def update_token_counter(text, steps):
     flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
     prompts = [prompt_text for step, prompt_text in flat_prompts]
     token_count, max_length = max([model_hijack.get_prompt_lengths(prompt) for prompt in prompts], key=lambda args: args[0])
-    style_class = ' class="red"' if (token_count > max_length) else ""
-    return f"<span {style_class}>{token_count}/{max_length}</span>"
+    return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
 
 
 def create_toprow(is_img2img):
     id_part = "img2img" if is_img2img else "txt2img"
 
-    with gr.Row(elem_id="toprow"):
-        with gr.Column(scale=6):
+    with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
+        with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
             with gr.Row():
                 with gr.Column(scale=80):
                     with gr.Row():
-                        prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2,
-                            placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)"
-                        )
+                        prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)")
 
             with gr.Row():
                 with gr.Column(scale=80):
                     with gr.Row():
-                        negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2,
-                            placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)"
-                        )
+                        negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)")
 
         with gr.Column(scale=1, elem_id="roll_col"):
             paste = gr.Button(value=paste_symbol, elem_id="paste")
@@ -365,6 +360,8 @@ def create_toprow(is_img2img):
             clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
             token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter")
             token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+            negative_token_counter = gr.HTML(value="", elem_id=f"{id_part}_negative_token_counter")
+            negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
 
             clear_prompt_button.click(
                 fn=lambda *x: x,
@@ -402,7 +399,7 @@ def create_toprow(is_img2img):
                 prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
                 create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
 
-    return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button
+    return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, negative_token_counter, negative_token_button
 
 
 def setup_progressbar(*args, **kwargs):
@@ -619,7 +616,7 @@ def create_ui():
     modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
 
     with
gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) + txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) @@ -795,12 +792,13 @@ def create_ui(): ] token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) + negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) modules.scripts.scripts_current = modules.scripts.scripts_img2img modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) + img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True) img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) @@ -1064,6 +1062,7 @@ def create_ui(): ) token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) + negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) img2img_paste_fields = [ (img2img_prompt, "Prompt"), diff --git a/style.css b/style.css index c10e32a1..994932de 100644 --- a/style.css +++ b/style.css @@ -2,12 +2,26 @@ max-width: 100%; } -#txt2img_token_counter { - height: 0px; +.token-counter{ + position: absolute; + display: inline-block; + right: 2em; + min-width: 0 !important; + width: auto; + z-index: 100; +} + +.token-counter.error span{ + box-shadow: 0 0 0.0 0.3em rgba(255,0,0,0.15), inset 0 0 0.6em rgba(255,0,0,0.075); + border: 2px solid rgba(255,0,0,0.4) !important; } -#img2img_token_counter { - height: 0px; +.token-counter div{ + display: inline; +} + +.token-counter span{ + padding: 0.1em 0.75em; } #sh{ @@ -113,7 +127,7 @@ #roll_col{ min-width: unset !important; flex-grow: 0 !important; - padding: 0.4em 0; + padding: 0 1em 0 0; gap: 0; } @@ -160,16 +174,6 @@ margin-bottom: 0; } -#toprow div.gr-box, #toprow div.gr-form{ - border: none; - gap: 0; - background: transparent; - box-shadow: none; -} -#toprow div{ - gap: 0; -} - #resize_mode{ flex: 1.5; } @@ -706,6 +710,14 @@ footer { opacity: 0.5; } +[id*='_prompt_container']{ + gap: 0; +} + +[id*='_prompt_container'] > div{ + margin: -0.4em 0 0 0; +} + .gr-compact { border: none; } @@ -715,8 +727,11 @@ footer { margin-left: 0.8em; } +.gr-compact{ + overflow: visible; +} + .gr-compact > *{ - margin-top: 0.5em !important; } .gr-compact .gr-block, .gr-compact .gr-form{ -- cgit v1.2.3 From 
7d3fb5cb03cc8520a32b4f56509f2e13e36911bd Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 20 Jan 2023 12:12:02 +0300 Subject: add margin to interrogate column in img2img UI --- style.css | 1 + 1 file changed, 1 insertion(+) diff --git a/style.css b/style.css index 994932de..3a515ebd 100644 --- a/style.css +++ b/style.css @@ -145,6 +145,7 @@ #interrogate_col{ min-width: 0 !important; max-width: 8em !important; + margin-right: 1em; } #interrogate, #deepbooru{ margin: 0em 0.25em 0.9em 0.25em; -- cgit v1.2.3 From e33cace2c2074ef342d027c1f31ffc4b3c3e877e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 20 Jan 2023 12:19:30 +0300 Subject: fix ctrl+up/down that stopped working --- javascript/edit-attention.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index ccc8344a..cec6a530 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -1,6 +1,6 @@ addEventListener('keydown', (event) => { let target = event.originalTarget || event.composedPath()[0]; - if (!target.matches("#toprow textarea.gr-text-input[placeholder]")) return; + if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return; if (! (event.metaKey || event.ctrlKey)) return; -- cgit v1.2.3 From e0b6092bc99efe311261a51289dec67cbf4845bc Mon Sep 17 00:00:00 2001 From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Date: Fri, 20 Jan 2023 15:31:27 +0100 Subject: Update webui.sh --- webui.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui.sh b/webui.sh index d5e7b3c5..8cdad22d 100755 --- a/webui.sh +++ b/webui.sh @@ -110,7 +110,7 @@ case "$gpu_info" in ;; *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0 printf "\n%s\n" "${delimiter}" - printf "Make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half" + printf "Experimental support for Renoir: make sure to have at least 4GB of VRAM and 10GB of RAM or enable cpu mode: --use-cpu all --no-half" printf "\n%s\n" "${delimiter}" ;; *) -- cgit v1.2.3 From 40ff6db5325fc34ad4fa35e80cb1e7768d9f7e75 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 08:36:07 +0300 Subject: extra networks UI rework of hypernets: rather than via settings, hypernets are added directly to prompt as --- html/card-no-preview.png | Bin 0 -> 84440 bytes html/extra-networks-card.html | 11 ++ html/extra-networks-no-cards.html | 8 ++ javascript/extraNetworks.js | 60 ++++++++ javascript/hints.js | 2 + javascript/ui.js | 9 +- modules/api/api.py | 7 +- modules/extra_networks.py | 147 +++++++++++++++++++ modules/extra_networks_hypernet.py | 21 +++ modules/generation_parameters_copypaste.py | 12 +- modules/hypernetworks/hypernetwork.py | 107 +++++++++----- modules/hypernetworks/ui.py | 5 +- modules/processing.py | 24 ++-- modules/sd_hijack_optimizations.py | 10 +- modules/shared.py | 21 ++- modules/textual_inversion/textual_inversion.py | 2 + modules/ui.py | 50 ++++--- modules/ui_components.py | 10 ++ modules/ui_extra_networks.py | 149 +++++++++++++++++++ modules/ui_extra_networks_hypernets.py | 34 +++++ modules/ui_extra_networks_textual_inversion.py | 32 +++++ script.js | 13 +- scripts/xy_grid.py | 29 ---- style.css | 190 +++++++++++++------------ webui.py | 26 +++- 25 files changed, 765 insertions(+), 214 deletions(-) create mode 100644 html/card-no-preview.png create mode 100644 html/extra-networks-card.html create mode 100644 
html/extra-networks-no-cards.html create mode 100644 javascript/extraNetworks.js create mode 100644 modules/extra_networks.py create mode 100644 modules/extra_networks_hypernet.py create mode 100644 modules/ui_extra_networks.py create mode 100644 modules/ui_extra_networks_hypernets.py create mode 100644 modules/ui_extra_networks_textual_inversion.py diff --git a/html/card-no-preview.png b/html/card-no-preview.png new file mode 100644 index 00000000..e2beb269 Binary files /dev/null and b/html/card-no-preview.png differ diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html new file mode 100644 index 00000000..7314b063 --- /dev/null +++ b/html/extra-networks-card.html @@ -0,0 +1,11 @@ +
<div class='card' {preview_html} onclick={card_clicked}>
+	<div class='actions'>
+		<div class='additional'>
+			<ul>
+				<li><a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a></li>
+			</ul>
+		</div>
+		<span class='name'>{name}</span>
+	</div>
+</div>
+
diff --git a/html/extra-networks-no-cards.html b/html/extra-networks-no-cards.html
new file mode 100644
index 00000000..389358d6
--- /dev/null
+++ b/html/extra-networks-no-cards.html
@@ -0,0 +1,8 @@
+<div class='nocards'>
+	<h1>Nothing here. Add some content to the following directories:</h1>
+
+	<ul>
+		{dirs}
+	</ul>
+</div>
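The two templates above are plain str.format() targets: the UI side fills {name}, {preview_html}, {card_clicked} and {save_card_preview} for each card, and {dirs} for the empty-state page. A hedged sketch of that fill step; the item name, paths and inline JS below are illustrative assumptions, not the exact ui_extra_networks.py code:

# Hypothetical card fill; the onclick snippets match the cardClicked and
# saveCardPreview signatures defined in extraNetworks.js below.
template = open("html/extra-networks-card.html").read()

card = template.format(
    name="my-hypernet",
    preview_html="",  # or a style='background-image: url(...)' attribute when a preview exists
    card_clicked="\"return cardClicked('<hypernet:my-hypernet:1.0>', true)\"",
    save_card_preview="\"return saveCardPreview(event, 'txt2img', 'models/hypernetworks/my-hypernet.png')\"",
)
print(card)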
+ diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js new file mode 100644 index 00000000..71e522d1 --- /dev/null +++ b/javascript/extraNetworks.js @@ -0,0 +1,60 @@ + +function setupExtraNetworksForTab(tabname){ + gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks') + + gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_refresh')) + gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_close')) +} + +var activePromptTextarea = null; +var activePositivePromptTextarea = null; + +function setupExtraNetworks(){ + setupExtraNetworksForTab('txt2img') + setupExtraNetworksForTab('img2img') + + function registerPrompt(id, isNegative){ + var textarea = gradioApp().querySelector("#" + id + " > label > textarea"); + + if (activePromptTextarea == null){ + activePromptTextarea = textarea + } + if (activePositivePromptTextarea == null && ! isNegative){ + activePositivePromptTextarea = textarea + } + + textarea.addEventListener("focus", function(){ + activePromptTextarea = textarea; + if(! isNegative) activePositivePromptTextarea = textarea; + }); + } + + registerPrompt('txt2img_prompt') + registerPrompt('txt2img_neg_prompt', true) + registerPrompt('img2img_prompt') + registerPrompt('img2img_neg_prompt', true) +} + +onUiLoaded(setupExtraNetworks) + +function cardClicked(textToAdd, allowNegativePrompt){ + textarea = allowNegativePrompt ? activePromptTextarea : activePositivePromptTextarea + + textarea.value = textarea.value + " " + textToAdd + updateInput(textarea) + + return false +} + +function saveCardPreview(event, tabname, filename){ + textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea') + button = gradioApp().getElementById(tabname + '_save_preview') + + textarea.value = filename + updateInput(textarea) + + button.click() + + event.stopPropagation() + event.preventDefault() +} diff --git a/javascript/hints.js b/javascript/hints.js index e746e20d..f4079f96 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -21,6 +21,8 @@ titles = { "\U0001F5D1": "Clear prompt", "\u{1f4cb}": "Apply selected styles to current prompt", "\u{1f4d2}": "Paste available values into the field", + "\u{1f3b4}": "Show extra networks", + "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", diff --git a/javascript/ui.js b/javascript/ui.js index 3ba90ca8..a7e75439 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -196,8 +196,6 @@ function confirm_clear_prompt(prompt, negative_prompt) { return [prompt, negative_prompt] } - - opts = {} onUiUpdate(function(){ if(Object.keys(opts).length != 0) return; @@ -239,11 +237,14 @@ onUiUpdate(function(){ return } + prompt.parentElement.insertBefore(counter, prompt) counter.classList.add("token-counter") prompt.parentElement.style.position = "relative" - textarea.addEventListener("input", () => update_token_counter(id_button)); + textarea.addEventListener("input", function(){ + update_token_counter(id_button); + }); } registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button') @@ -261,10 +262,8 @@ onUiUpdate(function(){ }) } } - }) - onOptionsChanged(function(){ elem = gradioApp().getElementById('sd_checkpoint_hash') sd_checkpoint_hash = 
opts.sd_checkpoint_hash || ""
diff --git a/modules/api/api.py b/modules/api/api.py
index 9814bbc2..2c371e6e 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -480,7 +480,7 @@ class Api:
     def train_hypernetwork(self, args: dict):
         try:
             shared.state.begin()
-            initial_hypernetwork = shared.loaded_hypernetwork
+            shared.loaded_hypernetworks = []
             apply_optimizations = shared.opts.training_xattention_optimizations
             error = None
             filename = ''
@@ -491,16 +491,15 @@ class Api:
             except Exception as e:
                 error = e
             finally:
-                shared.loaded_hypernetwork = initial_hypernetwork
                 shared.sd_model.cond_stage_model.to(devices.device)
                 shared.sd_model.first_stage_model.to(devices.device)
                 if not apply_optimizations:
                     sd_hijack.apply_optimizations()
                 shared.state.end()
-            return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
+            return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
         except AssertionError as msg:
             shared.state.end()
-            return TrainResponse(info = "train embedding error: {error}".format(error = error))
+            return TrainResponse(info="train embedding error: {error}".format(error=error))
 
     def get_memory(self):
         try:
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
new file mode 100644
index 00000000..1978673d
--- /dev/null
+++ b/modules/extra_networks.py
@@ -0,0 +1,147 @@
+import re
+from collections import defaultdict
+
+from modules import errors
+
+extra_network_registry = {}
+
+
+def initialize():
+    extra_network_registry.clear()
+
+
+def register_extra_network(extra_network):
+    extra_network_registry[extra_network.name] = extra_network
+
+
+class ExtraNetworkParams:
+    def __init__(self, items=None):
+        self.items = items or []
+
+
+class ExtraNetwork:
+    def __init__(self, name):
+        self.name = name
+
+    def activate(self, p, params_list):
+        """
+        Called by processing on every run. Whatever the extra network is meant to do should be activated here.
+        Passes arguments related to this extra network in params_list.
+        User passes arguments by specifying this in his prompt:
+
+        <name:arg1:arg2:arg3>
+
+        Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
+        separated by colon.
+
+        Even if the user does not mention this ExtraNetwork in his prompt, the call will still be made, with empty params_list -
+        in this case, all effects of this extra network should be disabled.
+
+        Can be called multiple times before deactivate() - each new call should override the previous call completely.
+
+        For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
+
+        > "1girl, <hypernet:agm:1.1> <hypernet:ray>"
+
+        params_list will be:
+
+        [
+            ExtraNetworkParams(items=["agm", "1.1"]),
+            ExtraNetworkParams(items=["ray"])
+        ]
+
+        """
+        raise NotImplementedError
+
+    def deactivate(self, p):
+        """
+        Called at the end of processing for housekeeping. No need to do anything here.
+ """ + + raise NotImplementedError + + +def activate(p, extra_network_data): + """call activate for extra networks in extra_network_data in specified order, then call + activate for all remaining registered networks with an empty argument list""" + + for extra_network_name, extra_network_args in extra_network_data.items(): + extra_network = extra_network_registry.get(extra_network_name, None) + if extra_network is None: + print(f"Skipping unknown extra network: {extra_network_name}") + continue + + try: + extra_network.activate(p, extra_network_args) + except Exception as e: + errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}") + + for extra_network_name, extra_network in extra_network_registry.items(): + args = extra_network_data.get(extra_network_name, None) + if args is not None: + continue + + try: + extra_network.activate(p, []) + except Exception as e: + errors.display(e, f"activating extra network {extra_network_name}") + + +def deactivate(p, extra_network_data): + """call deactivate for extra networks in extra_network_data in specified order, then call + deactivate for all remaining registered networks""" + + for extra_network_name, extra_network_args in extra_network_data.items(): + extra_network = extra_network_registry.get(extra_network_name, None) + if extra_network is None: + continue + + try: + extra_network.deactivate(p) + except Exception as e: + errors.display(e, f"deactivating extra network {extra_network_name}") + + for extra_network_name, extra_network in extra_network_registry.items(): + args = extra_network_data.get(extra_network_name, None) + if args is not None: + continue + + try: + extra_network.deactivate(p) + except Exception as e: + errors.display(e, f"deactivating unmentioned extra network {extra_network_name}") + + +re_extra_net = re.compile(r"<(\w+):([^>]+)>") + + +def parse_prompt(prompt): + res = defaultdict(list) + + def found(m): + name = m.group(1) + args = m.group(2) + + res[name].append(ExtraNetworkParams(items=args.split(":"))) + + return "" + + prompt = re.sub(re_extra_net, found, prompt) + + return prompt, res + + +def parse_prompts(prompts): + res = [] + extra_data = None + + for prompt in prompts: + updated_prompt, parsed_extra_data = parse_prompt(prompt) + + if extra_data is None: + extra_data = parsed_extra_data + + res.append(updated_prompt) + + return res, extra_data + diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py new file mode 100644 index 00000000..6a0c4ba8 --- /dev/null +++ b/modules/extra_networks_hypernet.py @@ -0,0 +1,21 @@ +from modules import extra_networks +from modules.hypernetworks import hypernetwork + + +class ExtraNetworkHypernet(extra_networks.ExtraNetwork): + def __init__(self): + super().__init__('hypernet') + + def activate(self, p, params_list): + names = [] + multipliers = [] + for params in params_list: + assert len(params.items) > 0 + + names.append(params.items[0]) + multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) + + hypernetwork.load_hypernetworks(names, multipliers) + + def deactivate(p, self): + pass diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index a381ff59..46e12dc6 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -79,8 +79,6 @@ def integrate_settings_paste_fields(component_dict): from modules import ui settings_map = { - 'sd_hypernetwork': 'Hypernet', - 'sd_hypernetwork_strength': 
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index a381ff59..46e12dc6 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -79,8 +79,6 @@ def integrate_settings_paste_fields(component_dict):
     from modules import ui
 
     settings_map = {
-        'sd_hypernetwork': 'Hypernet',
-        'sd_hypernetwork_strength': 'Hypernet strength',
         'CLIP_stop_at_last_layers': 'Clip skip',
         'inpainting_mask_weight': 'Conditional mask weight',
         'sd_model_checkpoint': 'Model hash',
@@ -275,13 +273,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     if "Clip skip" not in res:
         res["Clip skip"] = "1"
 
-    if "Hypernet strength" not in res:
-        res["Hypernet strength"] = "1"
-
-    if "Hypernet" in res:
-        hypernet_name = res["Hypernet"]
-        hypernet_hash = res.get("Hypernet hash", None)
-        res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
+    hypernet = res.get("Hypernet", None)
+    if hypernet is not None:
+        res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
 
     if "Hires resize-1" not in res:
         res["Hires resize-1"] = 0
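The net effect of the hunk above: an old-style infotext that carried separate Hypernet fields now has them folded back into the prompt as an extra-network tag on paste, so the new prompt-tag code path picks them up. A small sketch of that round-trip, with invented values standing in for the parsed res dict:

    # hypothetical parse result for an old-style infotext
    res = {"Prompt": "a castle", "Hypernet": "anime_3", "Hypernet strength": "0.6"}

    hypernet = res.get("Hypernet", None)
    if hypernet is not None:
        res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""

    print(res["Prompt"])  # a castle<hypernet:anime_3:0.6>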
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 74e78582..80a47c79 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -25,7 +25,6 @@ from statistics import stdev, mean
 optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
 
 class HypernetworkModule(torch.nn.Module):
-    multiplier = 1.0
     activation_dict = {
         "linear": torch.nn.Identity,
         "relu": torch.nn.ReLU,
@@ -41,6 +40,8 @@ class HypernetworkModule(torch.nn.Module):
                  add_layer_norm=False, activate_output=False, dropout_structure=None):
         super().__init__()
 
+        self.multiplier = 1.0
+
         assert layer_structure is not None, "layer_structure must not be None"
         assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
         assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
@@ -115,7 +116,7 @@ class HypernetworkModule(torch.nn.Module):
             state_dict[to] = x
 
     def forward(self, x):
-        return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
+        return x + self.linear(x) * (self.multiplier if not self.training else 1)
 
     def trainables(self):
         layer_structure = []
@@ -125,9 +126,6 @@ class HypernetworkModule(torch.nn.Module):
         return layer_structure
 
 
-def apply_strength(value=None):
-    HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
-
 #param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check.
 def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
     if layer_structure is None:
@@ -192,6 +190,20 @@ class Hypernetwork:
             for param in layer.parameters():
                 param.requires_grad = mode
 
+    def to(self, device):
+        for k, layers in self.layers.items():
+            for layer in layers:
+                layer.to(device)
+
+        return self
+
+    def set_multiplier(self, multiplier):
+        for k, layers in self.layers.items():
+            for layer in layers:
+                layer.multiplier = multiplier
+
+        return self
+
     def eval(self):
         for k, layers in self.layers.items():
             for layer in layers:
@@ -269,11 +281,13 @@ class Hypernetwork:
                 self.optimizer_state_dict = None
         if self.optimizer_state_dict:
             self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
-            print("Loaded existing optimizer from checkpoint")
-            print(f"Optimizer name is {self.optimizer_name}")
+            if shared.opts.print_hypernet_extra:
+                print("Loaded existing optimizer from checkpoint")
+                print(f"Optimizer name is {self.optimizer_name}")
         else:
             self.optimizer_name = "AdamW"
-            print("No saved optimizer exists in checkpoint")
+            if shared.opts.print_hypernet_extra:
+                print("No saved optimizer exists in checkpoint")
 
         for size, sd in state_dict.items():
             if type(size) == int:
@@ -306,23 +320,43 @@ def list_hypernetworks(path):
     return res
 
 
-def load_hypernetwork(filename):
-    path = shared.hypernetworks.get(filename, None)
-    # Prevent any file named "None.pt" from being loaded.
-    if path is not None and filename != "None":
-        print(f"Loading hypernetwork {filename}")
-        try:
-            shared.loaded_hypernetwork = Hypernetwork()
-            shared.loaded_hypernetwork.load(path)
+def load_hypernetwork(name):
+    path = shared.hypernetworks.get(name, None)
 
-        except Exception:
-            print(f"Error loading hypernetwork {path}", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
-    else:
-        if shared.loaded_hypernetwork is not None:
-            print("Unloading hypernetwork")
+    if path is None:
+        return None
+
+    hypernetwork = Hypernetwork()
+
+    try:
+        hypernetwork.load(path)
+    except Exception:
+        print(f"Error loading hypernetwork {path}", file=sys.stderr)
+        print(traceback.format_exc(), file=sys.stderr)
+        return None
+
+    return hypernetwork
+
+
+def load_hypernetworks(names, multipliers=None):
+    already_loaded = {}
+
+    for hypernetwork in shared.loaded_hypernetworks:
+        if hypernetwork.name in names:
+            already_loaded[hypernetwork.name] = hypernetwork
 
-        shared.loaded_hypernetwork = None
+    shared.loaded_hypernetworks.clear()
+
+    for i, name in enumerate(names):
+        hypernetwork = already_loaded.get(name, None)
+        if hypernetwork is None:
+            hypernetwork = load_hypernetwork(name)
+
+        if hypernetwork is None:
+            continue
+
+        hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
+        shared.loaded_hypernetworks.append(hypernetwork)
 
 
 def find_closest_hypernetwork_name(search: str):
@@ -336,18 +370,27 @@ def find_closest_hypernetwork_name(search: str):
     return applicable[0]
 
 
-def apply_hypernetwork(hypernetwork, context, layer=None):
-    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)
 
     if hypernetwork_layers is None:
-        return context, context
+        return context_k, context_v
 
     if layer is not None:
         layer.hyper_k = hypernetwork_layers[0]
         layer.hyper_v = hypernetwork_layers[1]
 
-    context_k = hypernetwork_layers[0](context)
-    context_v = 
hypernetwork_layers[1](context) + context_k = hypernetwork_layers[0](context_k) + context_v = hypernetwork_layers[1](context_v) + return context_k, context_v + + +def apply_hypernetworks(hypernetworks, context, layer=None): + context_k = context + context_v = context + for hypernetwork in hypernetworks: + context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer) + return context_k, context_v @@ -357,7 +400,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self) + context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self) k = self.to_k(context_k) v = self.to_v(context_v) @@ -464,8 +507,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi template_file = template_file.path path = shared.hypernetworks.get(hypernetwork_name, None) - shared.loaded_hypernetwork = Hypernetwork() - shared.loaded_hypernetwork.load(path) + hypernetwork = Hypernetwork() + hypernetwork.load(path) + shared.loaded_hypernetworks = [hypernetwork] shared.state.job = "train-hypernetwork" shared.state.textinfo = "Initializing hypernetwork training..." @@ -489,7 +533,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi else: images_dir = None - hypernetwork = shared.loaded_hypernetwork checkpoint = sd_models.select_checkpoint() initial_step = hypernetwork.step or 0 diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 81e3f519..76599f5a 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -9,6 +9,7 @@ from modules import devices, sd_hijack, shared not_available = ["hardswish", "multiheadattention"] keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) + def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure) @@ -16,8 +17,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, def train_hypernetwork(*args): - - initial_hypernetwork = shared.loaded_hypernetwork + shared.loaded_hypernetworks = [] assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible' @@ -34,7 +34,6 @@ Hypernetwork saved to {html.escape(filename)} except Exception: raise finally: - shared.loaded_hypernetwork = initial_hypernetwork shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device) sd_hijack.apply_optimizations() diff --git a/modules/processing.py b/modules/processing.py index a3e9f709..b5deeacf 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import 
modules.shared as shared @@ -438,9 +438,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), - "Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()), - "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), @@ -468,14 +465,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try: for k, v in p.override_settings.items(): setattr(opts, k, v) - if k == 'sd_hypernetwork': - shared.reload_hypernetworks() # make onchange call for changing hypernet if k == 'sd_model_checkpoint': - sd_models.reload_model_weights() # make onchange call for changing SD model + sd_models.reload_model_weights() if k == 'sd_vae': - sd_vae.reload_vae_weights() # make onchange call for changing VAE + sd_vae.reload_vae_weights() res = process_images_inner(p) @@ -484,9 +479,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if p.override_settings_restore_afterwards: for k, v in stored_opts.items(): setattr(opts, k, v) - if k == 'sd_hypernetwork': shared.reload_hypernetworks() - if k == 'sd_model_checkpoint': sd_models.reload_model_weights() - if k == 'sd_vae': sd_vae.reload_vae_weights() + if k == 'sd_model_checkpoint': + sd_models.reload_model_weights() + + if k == 'sd_vae': + sd_vae.reload_vae_weights() return res @@ -564,10 +561,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: cache[0] = (required_prompts, steps) return cache[1] + p.all_prompts, extra_network_data = extra_networks.parse_prompts(p.all_prompts) + with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) + extra_networks.activate(p, extra_network_data) + with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: processed = Processed(p, [], p.seed, "") file.write(processed.infotext(p, 0)) @@ -681,6 +682,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) + extra_networks.deactivate(p, extra_network_data) devices.torch_gc() res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index cdc63ed7..4fa54329 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -44,7 +44,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q_in = self.to_q(x) context = 
default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k_in = self.to_k(context_k) v_in = self.to_v(context_v) del context, context_k, context_v, x @@ -78,7 +78,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k_in = self.to_k(context_k) v_in = self.to_v(context_v) @@ -203,7 +203,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k = self.to_k(context_k) * self.scale v = self.to_v(context_v) del context, context_k, context_v, x @@ -225,7 +225,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k = self.to_k(context_k) v = self.to_v(context_v) del context, context_k, context_v, x @@ -284,7 +284,7 @@ def xformers_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k_in = self.to_k(context_k) v_in = self.to_v(context_v) diff --git a/modules/shared.py b/modules/shared.py index 2f366454..c0e11f18 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -23,6 +23,7 @@ demo = None sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml") sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file + parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) @@ -145,7 +146,7 @@ config_filename = cmd_opts.ui_settings_file os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = {} -loaded_hypernetwork = None +loaded_hypernetworks = [] def reload_hypernetworks(): @@ -153,8 +154,6 @@ def reload_hypernetworks(): global hypernetworks hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) - hypernetwork.load_hypernetwork(opts.sd_hypernetwork) - class State: @@ -399,8 +398,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), - 
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), - "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), @@ -661,3 +658,17 @@ mem_mon.start() def listfiles(dirname): filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")] return [file for file in filenames if os.path.isfile(file)] + + +def html_path(filename): + return os.path.join(script_path, "html", filename) + + +def html(filename): + path = html_path(filename) + + if os.path.exists(path): + with open(path, encoding="utf8") as file: + return file.read() + + return "" diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5a7be422..4e90f690 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -50,6 +50,7 @@ class Embedding: self.sd_checkpoint = None self.sd_checkpoint_name = None self.optimizer_state_dict = None + self.filename = None def save(self, filename): embedding_data = { @@ -182,6 +183,7 @@ class EmbeddingDatabase: embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) embedding.vectors = vec.shape[0] embedding.shape = vec.shape[-1] + embedding.filename = path if self.expected_shape == -1 or self.expected_shape == embedding.shape: self.register_embedding(embedding, shared.sd_model) diff --git a/modules/ui.py b/modules/ui.py index 06c11848..d23b2b8e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -20,7 +20,7 @@ import numpy as np from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path @@ -90,6 +90,7 @@ refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 clear_prompt_symbol = '\U0001F5D1' # 🗑️ +extra_networks_symbol = '\U0001F3B4' # 🎴 def plaintext_to_html(text): @@ -324,6 +325,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: def update_token_counter(text, steps): try: + text, _ = extra_networks.parse_prompt(text) + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) @@ -354,10 +357,10 @@ def create_toprow(is_img2img): negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)") with gr.Column(scale=1, elem_id="roll_col"): - paste = gr.Button(value=paste_symbol, elem_id="paste") - save_style = gr.Button(value=save_style_symbol, 
elem_id="style_create") - prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + paste = ToolButton(value=paste_symbol, elem_id="paste") + clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks") + token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") negative_token_counter = gr.HTML(value="", elem_id=f"{id_part}_negative_token_counter") @@ -395,11 +398,14 @@ def create_toprow(is_img2img): outputs=[], ) - with gr.Row(): + with gr.Row(elem_id=f"{id_part}_styles_row"): prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True) create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles") - return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, negative_token_counter, negative_token_button + prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id="style_apply") + save_style = ToolButton(value=save_style_symbol, elem_id="style_create") + + return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button def setup_progressbar(*args, **kwargs): @@ -616,11 +622,15 @@ def create_ui(): modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False) + txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) + with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks: + from modules import ui_extra_networks + extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img') + with gr.Row().style(equal_height=False): with gr.Column(variant='compact', elem_id="txt2img_settings"): for category in ordered_ui_categories(): @@ -794,14 +804,20 @@ def create_ui(): token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) + ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) + modules.scripts.scripts_current = modules.scripts.scripts_img2img modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) with 
gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True) + img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True) img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) + with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks: + from modules import ui_extra_networks + extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img') + with FormRow().style(equal_height=False): with gr.Column(variant='compact', elem_id="img2img_settings"): copy_image_buttons = [] @@ -1064,6 +1080,8 @@ def create_ui(): token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) + ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) + img2img_paste_fields = [ (img2img_prompt, "Prompt"), (img2img_negative_prompt, "Negative prompt"), @@ -1666,10 +1684,8 @@ def create_ui(): download_localization = gr.Button(value='Download localization template', elem_id="download_localization") reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") - if os.path.exists("html/licenses.html"): - with open("html/licenses.html", encoding="utf8") as file: - with gr.TabItem("Licenses"): - gr.HTML(file.read(), elem_id="licenses") + with gr.TabItem("Licenses"): + gr.HTML(shared.html("licenses.html"), elem_id="licenses") gr.Button(value="Show all pages", elem_id="settings_show_all_pages") @@ -1756,11 +1772,9 @@ def create_ui(): if os.path.exists(os.path.join(script_path, "notification.mp3")): audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) - if os.path.exists("html/footer.html"): - with open("html/footer.html", encoding="utf8") as file: - footer = file.read() - footer = footer.format(versions=versions_html()) - gr.HTML(footer, elem_id="footer") + footer = shared.html("footer.html") + footer = footer.format(versions=versions_html()) + gr.HTML(footer, elem_id="footer") text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) settings_submit.click( diff --git a/modules/ui_components.py b/modules/ui_components.py index 97acff06..46324425 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -11,6 +11,16 @@ class ToolButton(gr.Button, gr.components.FormComponent): return "button" +class ToolButtonTop(gr.Button, gr.components.FormComponent): + """Small button with single emoji as text, with extra margin at top, fits inside gradio forms""" + + def __init__(self, **kwargs): + super().__init__(variant="tool-top", **kwargs) + + def get_block_name(self): + return "button" + + class FormRow(gr.Row, 
gr.components.FormComponent): """Same as gr.Row but fits inside gradio forms""" diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py new file mode 100644 index 00000000..253e90f7 --- /dev/null +++ b/modules/ui_extra_networks.py @@ -0,0 +1,149 @@ +import os.path + +from modules import shared +import gradio as gr +import json + +from modules.generation_parameters_copypaste import image_from_url_text + +extra_pages = [] + + +def register_page(page): + """registers extra networks page for the UI; recommend doing it in on_app_started() callback for extensions""" + + extra_pages.append(page) + + +class ExtraNetworksPage: + def __init__(self, title): + self.title = title + self.card_page = shared.html("extra-networks-card.html") + self.allow_negative_prompt = False + + def refresh(self): + pass + + def create_html(self, tabname): + items_html = '' + + for item in self.list_items(): + items_html += self.create_html_for_item(item, tabname) + + if items_html == '': + dirs = "".join([f"
<li>{x}</li>" for x in self.allowed_directories_for_previews()])
+            items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs)
+
+        res = "<div class='extra-network-cards'>" + items_html + "</div>"
+
+        return res
+
+    def list_items(self):
+        raise NotImplementedError()
+
+    def allowed_directories_for_previews(self):
+        return []
+
+    def create_html_for_item(self, item, tabname):
+        preview = item.get("preview", None)
+
+        args = {
+            "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '',
+            "prompt": json.dumps(item["prompt"]),
+            "tabname": json.dumps(tabname),
+            "local_preview": json.dumps(item["local_preview"]),
+            "name": item["name"],
+            "allow_negative_prompt": "true" if self.allow_negative_prompt else "false",
+        }
+
+        return self.card_page.format(**args)
+
+
+def intialize():
+    extra_pages.clear()
+
+
+class ExtraNetworksUi:
+    def __init__(self):
+        self.pages = None
+        self.stored_extra_pages = None
+
+        self.button_save_preview = None
+        self.preview_target_filename = None
+
+        self.tabname = None
+
+
+def create_ui(container, button, tabname):
+    ui = ExtraNetworksUi()
+    ui.pages = []
+    ui.stored_extra_pages = extra_pages.copy()
+    ui.tabname = tabname
+
+    with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
+        button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
+        button_close = gr.Button('Close', elem_id=tabname+"_extra_close")
+
+        for page in ui.stored_extra_pages:
+            with gr.Tab(page.title):
+                page_elem = gr.HTML(page.create_html(ui.tabname))
+                ui.pages.append(page_elem)
+
+    ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
+    ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)
+
+    button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container])
+    button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container])
+
+    def refresh():
+        res = []
+
+        for pg in ui.stored_extra_pages:
+            pg.refresh()
+            res.append(pg.create_html(ui.tabname))
+
+        return res
+
+    button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
+
+    return ui
+
+
+def path_is_parent(parent_path, child_path):
+    parent_path = os.path.abspath(parent_path)
+    child_path = os.path.abspath(child_path)
+
+    return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
+
+
+def setup_ui(ui, gallery):
+    def save_preview(index, images, filename):
+        if len(images) == 0:
+            print("There is no image in gallery to save as a preview.")
+            return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
+
+        index = int(index)
+        index = 0 if index < 0 else index
+        index = len(images) - 1 if index >= len(images) else index
+
+        img_info = images[index if index >= 0 else 0]
+        image = image_from_url_text(img_info)
+
+        is_allowed = False
+        for extra_page in ui.stored_extra_pages:
+            if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]):
+                is_allowed = True
+                break
+
+        assert is_allowed, f'writing to {filename} is not allowed'
+
+        image.save(filename)
+
+        return [page.create_html(ui.tabname) for page in ui.stored_extra_pages]
+
+    ui.button_save_preview.click(
+        fn=save_preview,
+        _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}",
+        inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
+        outputs=[*ui.pages]
+    )
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
new file mode 100644
index 00000000..312dbaf0
--- /dev/null
+++ b/modules/ui_extra_networks_hypernets.py
@@ -0,0 +1,34 @@
+import os
+
+from modules import shared, ui_extra_networks
+
+
+class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
+    def __init__(self):
+        super().__init__('Hypernetworks')
+
+    def refresh(self):
+        shared.reload_hypernetworks()
+
+    def list_items(self):
+        for name, path in shared.hypernetworks.items():
+            path, ext = os.path.splitext(path)
+            previews = [path + ".png", path + ".preview.png"]
+
+            preview = None
+            for file in previews:
+                if os.path.isfile(file):
+                    preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
+                    break
+
+            yield {
+                "name": name,
+                "filename": path,
+                "preview": preview,
+                "prompt": f"<hypernet:{name}:1.0>",
+                "local_preview": path + ".png",
+            }
+
+    def allowed_directories_for_previews(self):
+        return [shared.cmd_opts.hypernetwork_dir]
+
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
new file mode 100644
index 00000000..e4a6e3bf
--- /dev/null
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -0,0 +1,32 @@
+import os
+
+from modules import ui_extra_networks, sd_hijack
+
+
+class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
+    def __init__(self):
+        super().__init__('Textual Inversion')
+        self.allow_negative_prompt = True
+
+    def refresh(self):
+        sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
+
+    def list_items(self):
+        for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values():
+            path, ext = os.path.splitext(embedding.filename)
+            preview_file = path + ".preview.png"
+
+            preview = None
+            if os.path.isfile(preview_file):
+                preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file))
+
+            yield {
+                "name": embedding.name,
+                "filename": embedding.filename,
+                "preview": preview,
+                "prompt": embedding.name,
+                "local_preview": path + ".preview.png",
+            }
+
+    def allowed_directories_for_previews(self):
+        return list(sd_hijack.model_hijack.embedding_db.embedding_dirs)
diff --git a/script.js b/script.js
index 3345e32b..97e0bfcf 100644
--- a/script.js
+++ b/script.js
@@ -13,6 +13,7 @@ function get_uiCurrentTabContent() {
 }
 
 uiUpdateCallbacks = []
+uiLoadedCallbacks = []
 uiTabChangeCallbacks = []
 optionsChangedCallbacks = []
 let uiCurrentTab = null
@@ -20,6 +21,9 @@ let uiCurrentTab = null
 function onUiUpdate(callback){
     uiUpdateCallbacks.push(callback)
 }
+function onUiLoaded(callback){
+    uiLoadedCallbacks.push(callback)
+}
 function onUiTabChange(callback){
     uiTabChangeCallbacks.push(callback)
 }
@@ -38,8 +42,15 @@ function executeCallbacks(queue, m) {
     queue.forEach(function(x){runCallback(x, m)})
 }
 
+var executedOnLoaded = false;
+
 document.addEventListener("DOMContentLoaded", function() {
     var mutationObserver = new MutationObserver(function(m){
+        if(!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')){
+            executedOnLoaded = true;
+            executeCallbacks(uiLoadedCallbacks);
+        }
+
         executeCallbacks(uiUpdateCallbacks, m);
         const newTab = get_uiCurrentTab();
         if ( newTab && ( newTab !== uiCurrentTab ) ) {
@@ -53,7 +64,7 @@ document.addEventListener("DOMContentLoaded", function() {
 /**
  * Add a ctrl+enter as a shortcut to start a generation
  */
- document.addEventListener('keydown', function(e) {
+document.addEventListener('keydown', function(e) {
     var handled = false;
     if (e.key !== undefined) {
         if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 6629f5d5..b1badec9 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -11,7 +11,6 @@ import 
modules.scripts as scripts import gradio as gr from modules import images, paths, sd_samplers, processing, sd_models, sd_vae -from modules.hypernetworks import hypernetwork from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -94,28 +93,6 @@ def confirm_checkpoints(p, xs): raise RuntimeError(f"Unknown checkpoint: {x}") -def apply_hypernetwork(p, x, xs): - if x.lower() in ["", "none"]: - name = None - else: - name = hypernetwork.find_closest_hypernetwork_name(x) - if not name: - raise RuntimeError(f"Unknown hypernetwork: {x}") - hypernetwork.load_hypernetwork(name) - - -def apply_hypernetwork_strength(p, x, xs): - hypernetwork.apply_strength(x) - - -def confirm_hypernetworks(p, xs): - for x in xs: - if x.lower() in ["", "none"]: - continue - if not hypernetwork.find_closest_hypernetwork_name(x): - raise RuntimeError(f"Unknown hypernetwork: {x}") - - def apply_clip_skip(p, x, xs): opts.data["CLIP_stop_at_last_layers"] = x @@ -208,8 +185,6 @@ axis_options = [ AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), - AxisOption("Hypernetwork", str, apply_hypernetwork, format_value=format_value, confirm=confirm_hypernetworks, cost=0.2, choices=lambda: list(shared.hypernetworks)), - AxisOption("Hypernet str.", float, apply_hypernetwork_strength), AxisOption("Sigma Churn", float, apply_field("s_churn")), AxisOption("Sigma min", float, apply_field("s_tmin")), AxisOption("Sigma max", float, apply_field("s_tmax")), @@ -291,7 +266,6 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_ class SharedSettingsStackHelper(object): def __enter__(self): self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers - self.hypernetwork = opts.sd_hypernetwork self.vae = opts.sd_vae def __exit__(self, exc_type, exc_value, tb): @@ -299,9 +273,6 @@ class SharedSettingsStackHelper(object): modules.sd_models.reload_model_weights() modules.sd_vae.reload_vae_weights() - hypernetwork.load_hypernetwork(self.hypernetwork) - hypernetwork.apply_strength() - opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers diff --git a/style.css b/style.css index 3a515ebd..5e8bc2ca 100644 --- a/style.css +++ b/style.css @@ -132,13 +132,6 @@ } #roll_col > button { - min-width: 2em; - min-height: 2em; - max-width: 2em; - max-height: 2em; - flex-grow: 0; - padding-left: 0.25em; - padding-right: 0.25em; margin: 0.1em 0; } @@ -146,9 +139,10 @@ min-width: 0 !important; max-width: 8em !important; margin-right: 1em; + gap: 0; } #interrogate, #deepbooru{ - margin: 0em 0.25em 0.9em 0.25em; + margin: 0em 0.25em 0.5em 0.25em; min-width: 8em; max-width: 8em; } @@ -157,8 +151,17 @@ min-width: 8em !important; } +#txt2img_styles_row, #img2img_styles_row{ + gap: 0.25em; + margin-top: 0.5em; +} + +#txt2img_styles_row > button, #img2img_styles_row > button{ + margin: 0; +} + #txt2img_styles, #img2img_styles{ - margin-top: 1em; + padding: 0; } #txt2img_styles ul, #img2img_styles ul{ @@ -635,17 +638,21 @@ canvas[key="mask"] { background-color: rgb(31 41 55 / var(--tw-bg-opacity)); } -.gr-button-tool{ +.gr-button-tool, 
.gr-button-tool-top{ max-width: 2.5em; min-width: 2.5em !important; height: 2.4em; - margin: 1.6em 0.7em 0.55em 0; } -#tab_modelmerger .gr-button-tool{ +.gr-button-tool{ margin: 0.6em 0em 0.55em 0; } +.gr-button-tool-top, #settings .gr-button-tool{ + margin: 1.6em 0.7em 0.55em 0; +} + + #modelmerger_results_container{ margin-top: 1em; overflow: visible; @@ -763,81 +770,88 @@ footer { line-height: 2.4em; } -/* The following handles localization for right-to-left (RTL) languages like Arabic. -The rtl media type will only be activated by the logic in javascript/localization.js. -If you change anything above, you need to make sure it is RTL compliant by just running -your changes through converters like https://cssjanus.github.io/ or https://rtlcss.com/. -Then, you will need to add the RTL counterpart only if needed in the rtl section below.*/ -@media rtl { - /* this part was added manually */ - :host { - direction: rtl; - } - select, .file-preview, .gr-text-input, .output-html:has(.performance), #ti_progress { - direction: ltr; - } - #script_list > label > select, - #x_type > label > select, - #y_type > label > select { - direction: rtl; - } - .gr-radio, .gr-checkbox{ - margin-left: 0.25em; - } +#txt2img_extra_networks, #img2img_extra_networks{ + margin-top: -1em; +} - /* automatically generated with few manual modifications */ - .performance .time { - margin-right: unset; - margin-left: 0; - } - .justify-center.overflow-x-scroll { - justify-content: right; - } - .justify-center.overflow-x-scroll button:first-of-type { - margin-left: unset; - margin-right: auto; - } - .justify-center.overflow-x-scroll button:last-of-type { - margin-right: unset; - margin-left: auto; - } - #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{ - margin-right: unset; - margin-left: 8em; - } - #txt2img_progressbar, #img2img_progressbar, #ti_progressbar{ - right: unset; - left: 0; - } - .progressDiv .progress{ - padding: 0 0 0 8px; - text-align: left; - } - #lightboxModal{ - left: unset; - right: 0; - } - .modalPrev, .modalNext{ - border-radius: 3px 0 0 3px; - } - .modalNext { - right: unset; - left: 0; - border-radius: 0 3px 3px 0; - } - #imageARPreview{ - left:unset; - right:0px; - } - #txt2img_skip, #img2img_skip{ - right: unset; - left: 0px; - } - #context-menu{ - box-shadow:-1px 1px 2px #CE6400; - } - .gr-box > div > div > input.gr-text-input{ - right: unset; - left: 0.5em; - } +.extra-networks > div > [id *= '_extra_']{ + margin: 0.3em; } + +.extra-network-cards .nocards{ + margin: 1.25em 0.5em 0.5em 0.5em; +} + +.extra-network-cards .nocards h1{ + font-size: 1.5em; + margin-bottom: 1em; +} + +.extra-network-cards .nocards li{ + margin-left: 0.5em; +} + +.extra-network-cards .card{ + display: inline-block; + margin: 0.5em; + width: 16em; + height: 24em; + box-shadow: 0 0 5px rgba(128, 128, 128, 0.5); + border-radius: 0.2em; + position: relative; + + background-size: auto 100%; + background-position: center; + overflow: hidden; + cursor: pointer; + + background-image: url('./file=html/card-no-preview.png') +} + +.extra-network-cards .card:hover{ + box-shadow: 0 0 2px 0.3em rgba(0, 128, 255, 0.35); +} + +.extra-network-cards .card .actions .additional{ + display: none; +} + +.extra-network-cards .card .actions{ + position: absolute; + bottom: 0; + left: 0; + right: 0; + padding: 0.5em; + color: white; + background: rgba(0,0,0,0.5); + box-shadow: 0 0 0.25em 0.25em rgba(0,0,0,0.5); + text-shadow: 0 0 0.2em black; +} + +.extra-network-cards .card 
.actions:hover{ + box-shadow: 0 0 0.75em 0.75em rgba(0,0,0,0.5) !important; +} + +.extra-network-cards .card .actions .name{ + font-size: 1.7em; + font-weight: bold; + line-break: anywhere; +} + +.extra-network-cards .card .actions:hover .additional{ + display: block; +} + +.extra-network-cards .card ul{ + margin: 0.25em 0 0.75em 0.25em; + cursor: unset; +} + +.extra-network-cards .card ul a{ + cursor: pointer; +} + +.extra-network-cards .card ul a:hover{ + color: red; +} + diff --git a/webui.py b/webui.py index 865a7300..e8dd822a 100644 --- a/webui.py +++ b/webui.py @@ -9,16 +9,18 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware -from modules import import_hook, errors +from modules import import_hook, errors, extra_networks +from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path import torch + # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors if ".dev" in torch.__version__ or "+git" in torch.__version__: torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) -from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir +from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks import modules.codeformer_model as codeformer import modules.extras import modules.face_restoration @@ -84,10 +86,17 @@ def initialize(): shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) - shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks())) - shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) + shared.reload_hypernetworks() + + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: try: @@ -209,6 +218,15 @@ def webui(): modules.sd_models.list_models() + shared.reload_hypernetworks() + + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + if __name__ == "__main__": if cmd_opts.nowebui: -- cgit v1.2.3 From 6d805b669e86233432f56ee1892d062103abe501 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 09:14:27 +0300 Subject: make CLIP interrogator download original text files if the directory does not exist remove random artist built-in extension (to re-added as 
a normal extension on demand) remove artists.csv (but what does it mean????????????????????) make interrogate buttons show Loading... when you click them --- README.md | 1 - artists.csv | 3041 -------------------- .../roll-artist/scripts/roll-artist.py | 50 - javascript/hints.js | 1 - modules/api/api.py | 8 - modules/artists.py | 25 - modules/interrogate.py | 55 +- modules/shared.py | 5 - modules/ui.py | 11 +- 9 files changed, 46 insertions(+), 3151 deletions(-) delete mode 100644 artists.csv delete mode 100644 extensions-builtin/roll-artist/scripts/roll-artist.py delete mode 100644 modules/artists.py diff --git a/README.md b/README.md index d783fdf0..1ac794e8 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,6 @@ A browser interface based on Gradio library for Stable Diffusion. - Running arbitrary python code from UI (must run with --allow-code to enable) - Mouseover hints for most UI elements - Possible to change defaults/mix/max/step values for UI elements via text config -- Random artist button - Tiling support, a checkbox to create images that can be tiled like textures - Progress bar and live image generation preview - Negative prompt, an extra text field that allows you to list what you don't want to see in generated image diff --git a/artists.csv b/artists.csv deleted file mode 100644 index 1a61ed88..00000000 --- a/artists.csv +++ /dev/null @@ -1,3041 +0,0 @@ -artist,score,category -Peter Max,0.99715996,weird -Roy Lichtenstein,0.98272276,cartoon -Romero Britto,0.9498342,scribbles -Keith Haring,0.9431302,weird -Hiroshige,0.93995106,ukioe -Joan Miró,0.9169429,scribbles -Jean-Michel Basquiat,0.90080947,scribbles -Katsushika Hokusai,0.8887236,ukioe -Paul Klee,0.8868682,scribbles -Marc Chagall,0.8868168,scribbles -Karl Schmidt-Rottluff,0.88444495,scribbles -Howard Hodgkin,0.8808578,scribbles -Jean Metzinger,0.88056004,scribbles -Alma Thomas,0.87658304,weird -Rufino Tamayo,0.8749848,scribbles -Utagawa Hiroshige,0.8728796,ukioe -Chagall,0.8718535,scribbles -Harumi Hironaka,0.86914605,scribbles -Hans Hofmann,0.8686159,scribbles -Kawanabe Kyōsai,0.86612236,ukioe -Andy Warhol,0.8654825,scribbles -Barbara Takenaga,0.86223894,scribbles -Tatsuro Kiuchi,0.8597267,cartoon -Vincent Van Gogh,0.85538065,scribbles -Wassily Kandinsky,0.85490596,scribbles -Georges Seurat,0.8534801,scribbles -Karel Appel,0.8529153,scribbles -Sonia Delaunay,0.8506156,scribbles -Hokusai,0.85046995,ukioe -Eduardo Kobra,0.85036755,weird -Fra Angelico,0.84984255,fineart -Milton Avery,0.849746,scribbles -David Hockney,0.8496144,scribbles -Hiroshi Nagai,0.847129,cartoon -Aristarkh Lentulov,0.846537,scribbles -Lyonel Feininger,0.84573764,scribbles -Mary Blair,0.845709,scribbles -Ellsworth Kelly,0.8455428,scribbles -Jun Kaneko,0.8448367,scribbles -Roz Chast,0.8432013,weird -Ida Rentoul Outhwaite,0.84275174,scribbles -Robert Motherwell,0.8409468,scribbles -Garry Winogrand,0.83994275,black-white -Andrei Rublev,0.83950496,fineart -Alexander Calder,0.83832693,scribbles -Tomokazu Matsuyama,0.8376121,scribbles -August Macke,0.8362022,scribbles -Kazimir Malevich,0.8356527,scribbles -Richard Scarry,0.83554685,scribbles -Victor Vasarely,0.8335438,scribbles -Kitagawa Utamaro,0.83333457,ukioe -Matt Bors,0.83252287,scribbles -Emil Nolde,0.8323225,scribbles -Patrick Caulfield,0.8322225,scribbles -Charles Blackman,0.83200824,scribbles -Peter Doig,0.83111644,scribbles -Alexej von Jawlensky,0.8308932,scribbles -Rumiko Takahashi,0.8301817,anime -Eileen Agar,0.82945526,scribbles -Ernst Ludwig Kirchner,0.82756275,scribbles -Nicolas 
Delort,0.8261329,scribbles -Marsden Hartley,0.8250993,scribbles -Keith Negley,0.8212553,scribbles -Jamini Roy,0.8212199,scribbles -Quentin Blake,0.82115215,scribbles -Andy Kehoe,0.82063186,cartoon -George barbier,0.82046914,fineart -Frans Masereel,0.81997275,scribbles -Umberto Boccioni,0.81921184,scribbles -Conrad Roset,0.8190752,cartoon -Paul Ranson,0.81903255,scribbles -Yayoi Kusama,0.81886625,weird -Tomi Ungerer,0.81848705,scribbles -Saul Steinberg,0.81778854,scribbles -Jon Klassen,0.81773067,scribbles -W.W. Denslow,0.81708044,fineart -Helen Frankenthaler,0.81704986,scribbles -Jean Jullien,0.816437,scribbles -Brett Whiteley,0.81601924,scribbles -Giotto Di Bondone,0.81427747,fineart -Takashi Murakami,0.81338763,weird -Howard Finster,0.81333554,scribbles -Eduardo Paolozzi,0.81312317,scribbles -Charles Rennie Mackintosh,0.81297064,scribbles -Brandon Mably,0.8128239,weird -Rebecca Louise Law,0.81214285,weird -Victo Ngai,0.81195843,cartoon -Hanabusa Itchō II,0.81187993,ukioe -Edmund Dulac,0.81104875,scribbles -Ben Shahn,0.8104582,scribbles -Howard Arkley,0.8103746,scribbles -Wilfredo Lam,0.8096211,scribbles -Michael Deforge,0.8095954,scribbles -John Hoyland,0.8094592,fineart -Francesco Clemente,0.8090387,scribbles -Leonetto Cappiello,0.8087691,scribbles -Norman Ackroyd,0.80788493,scribbles -Bhupen Khakhar,0.8077607,scribbles -Jeremiah Ketner,0.8075384,cartoon -Chris Ofili,0.8073793,scribbles -Banksy,0.80695426,scribbles -Tom Whalen,0.805867,scribbles -Ernst Wilhelm Nay,0.805295,scribbles -Henri Rousseau,0.8049866,scribbles -Kunisada,0.80493814,ukioe -Naoko Takeuchi,0.80482674,anime -Kaethe Butcher,0.80406916,scribbles -Hasui Kawase,0.8040483,ukioe -Alvin Langdon Coburn,0.8035004,black-white -Stanley Donwood,0.8033054,scribbles -Agnes Martin,0.8028028,scribbles -Osamu Tezuka,0.8005524,cartoon -Frank Stella,0.80049455,scribbles -Dale Chihuly,0.79982775,digipa-high-impact -Evgeni Gordiets,0.79967916,scribbles -Janek Sedlar,0.7993992,fineart -Alasdair Gray,0.7992301,scribbles -Yasuo Kuniyoshi,0.79870003,ukioe -Edward Gorey,0.7984938,scribbles -Johannes Itten,0.798481,scribbles -Cuno Amiet,0.7979497,scribbles -M.C. 
Escher,0.7976657,scribbles
-Albert Irvin,0.79688835,scribbles
-Jack Gaughan,0.79443675,scribbles
-Ravi Zupa,0.7939542,scribbles
-Kay Nielsen,0.79385525,scribbles
-Agnolo Gaddi,0.79369193,fineart
-Alessandro Gottardo,0.79321593,scribbles
-Paul Laffoley,0.79196846,scribbles
-Giovanni Battista Piranesi,0.79111177,fineart
-Adrian Tomine,0.79109013,scribbles
-Adolph Gottlieb,0.79061794,scribbles
-Milton Caniff,0.7905358,cartoon
-Philip Guston,0.78994095,scribbles
-Debbie Criswell,0.7895031,cartoon
-Alice Pasquini,0.78949904,cartoon
-Johannes Vermeer,0.78931487,fineart
-Lisa Frank,0.7892591,cartoon
-Patrick Heron,0.78889126,scribbles
-Mikhail Nesterov,0.78814346,fineart
-Cézanne,0.7879481,scribbles
-Tristan Eaton,0.787513,scribbles
-Jillian Tamaki,0.7868066,scribbles
-Takato Yamamoto,0.78460765,ukioe
-Martiros Saryan,0.7844924,scribbles
-Emil Orlik,0.7842625,scribbles
-Armand Guillaumin,0.7840431,scribbles
-Jane Newland,0.7837676,scribbles
-Paul Cézanne,0.78368753,scribbles
-Tove Jansson,0.78356475,scribbles
-Guido Crepax,0.7835321,cartoon
-OSGEMEOS,0.7829088,weird
-Albert Watson,0.48901254,digipa-med-impact
-Emory Douglas,0.78179604,scribbles
-Chris Van Allsburg,0.66413003,fineart
-Ohara Koson,0.78132576,ukioe
-Nicolas de Stael,0.7802779,scribbles
-Aubrey Beardsley,0.77970016,scribbles
-Hishikawa Moronobu,0.7794119,ukioe
-Alfred Wallis,0.77926695,scribbles
-Friedensreich Hundertwasser,0.7791805,scribbles
-Eyvind Earle,0.7788089,scribbles
-Giotto,0.7785216,fineart
-Simone Martini,0.77843,fineart
-Ivan Bilibin,0.77720606,fineart
-Karl Blossfeldt,0.77652574,black-white
-Duy Huynh,0.77634746,scribbles
-Giovanni da Udina,0.7763063,fineart
-Henri-Edmond Cross,0.7762994,fineart
-Barry McGee,0.77618384,scribbles
-William Kentridge,0.77615225,scribbles
-Alexander Archipenko,0.7759824,scribbles
-Jaume Plensa,0.7756799,weird
-Bill Jacklin,0.77504414,fineart
-Alberto Vargas,0.7747376,cartoon
-Jean Dubuffet,0.7744374,scribbles
-Eugène Grasset,0.7741958,fineart
-Arthur Rackham,0.77418125,fineart
-Yves Tanguy,0.77380997,scribbles
-Elsa Beskow,0.7736908,fineart
-Georgia O’Keeffe,0.77368987,scribbles
-Georgia O'Keeffe,0.77368987,scribbles
-Henri Cartier-Bresson,0.7735415,black-white
-Andrea del Verrocchio,0.77307427,fineart
-Mark Rothko,0.77294236,scribbles
-Bruce Gilden,0.7256681,black-white
-Gino Severini,0.77247965,scribbles
-Delphin Enjolras,0.5594248,fineart
-Alena Aenami,0.77210015,cartoon
-Ed Freeman,0.42526615,digipa-low-impact
-Apollonia Saintclair,0.7718383,anime
-László Moholy-Nagy,0.771497,scribbles
-Louis Glackens,0.7713224,fineart
-Fang Lijun,0.77097225,fineart
-Alfred Kubin,0.74409986,fineart
-David Wojnarowicz,0.7705802,scribbles
-Tara McPherson,0.77023256,scribbles
-Gustav Doré,0.7367536,fineart
-Patricia Polacco,0.7696109,scribbles
-Norman Bluhm,0.7692634,fineart
-Elizabeth Gadd,0.7691194,digipa-high-impact
-Gabriele Münter,0.7690926,scribbles
-David Inshaw,0.76905304,scribbles
-Maurice Sendak,0.7690118,cartoon
-Harry Clarke,0.7688428,cartoon
-Howardena Pindell,0.7686921,n
-Jamie Hewlett,0.7680373,scribbles
-Steve Ditko,0.76725733,scribbles
-Annie Soudain,0.7671485,scribbles
-Albert Gleizes,0.76658314,scribbles
-Henry Fuseli,0.69147265,fineart
-Alain Laboile,0.67634284,c
-Albrecht Altdorfer,0.7663378,fineart
-Jack Butler Yeats,0.7661406,fineart
-Yue Minjun,0.76583517,scribbles
-Art Spiegelman,0.7656343,scribbles
-Grete Stern,0.7656276,fineart
-Mordecai Ardon,0.7648692,scribbles
-Joel Sternfeld,0.76456416,digipa-high-impact
-Milton Glaser,0.7641823,scribbles
-Eishōsai Chōki,0.7639659,scribbles
-Domenico Ghirlandaio,0.76372653,fineart
-Alex Timmermans,0.64443207,digipa-high-impact
-Andreas Vesalius,0.763446,fineart
-Bruce McLean,0.76335883,scribbles
-Jacob Lawrence,0.76330304,scribbles
-Alex Katz,0.76317835,scribbles
-Henri de Toulouse-Lautrec,0.76268333,scribbles
-Franz Sedlacek,0.762062,scribbles
-Paul Lehr,0.70854837,cartoon
-Nicholas Roerich,0.76117516,scribbles
-Henri Matisse,0.76110923,scribbles
-Colin McCahon,0.76086944,scribbles
-Max Dupain,0.6661642,black-white
-Stephen Gammell,0.74001735,weird
-Alberto Giacometti,0.7596302,scribbles
-Goyō Hashiguchi,0.7595048,ukioe
-Gustave Doré,0.7018832,fineart
-Butcher Billy,0.7593378,cartoon
-Pieter de Hooch,0.75916564,fineart
-Gaetano Pesce,0.75906265,scribbles
-Winsor McCay,0.7589382,scribbles
-Claude Cahun,0.7588153,weird
-Roger Ballen,0.64683115,black-white
-Ellen Gallagher,0.758621,scribbles
-Anton Corbijn,0.5550669,digipa-high-impact
-Margaret Macdonald Mackintosh,0.75781375,fineart
-Franz Kline,0.7576461,scribbles
-Cimabue,0.75720495,fineart
-André Kertész,0.7319392,black-white
-Hans Hartung,0.75718236,scribbles
-J. J. Grandville,0.7321584,fineart
-David Octavius Hill,0.6333561,digipa-high-impact
-teamLab,0.7566472,digipa-high-impact
-Paul Gauguin,0.75635266,scribbles
-Etel Adnan,0.75631833,scribbles
-Barbara Kruger,0.7562784,scribbles
-Franz Marc,0.75538874,scribbles
-Saul Bass,0.75496316,scribbles
-El Lissitzky,0.7549487,scribbles
-Thomas Moran,0.6507399,fineart
-Claude Monet,0.7541377,fineart
-David Young Cameron,0.7541016,scribbles
-W. Heath Robinson,0.75374347,cartoon
-Yves Klein,0.7536262,fineart
-Albert Pinkham Ryder,0.7338848,fineart
-Elizabeth Shippen Green,0.7533686,fineart
-Robert Stivers,0.5516287,fineart
-Emily Kame Kngwarreye,0.7532016,weird
-Charline von Heyl,0.753142,scribbles
-Frida Kahlo,0.75303876,scribbles
-Amy Sillman,0.752921,scribbles
-Emperor Huizong of Song,0.7525214,ukioe
-Edward Burne-Jones,0.75220466,fineart
-Brett Weston,0.6891357,black-white
-Charles E. Burchfield,0.75174403,scribbles
-Hishida Shunsō,0.751617,fareast
-Elaine de Kooning,0.7514996,scribbles
-Gary Panter,0.7514598,scribbles
-Frederick Hammersley,0.7514268,scribbles
-Gustave Dore,0.6735896,fineart
-Ephraim Moses Lilien,0.7510494,fineart
-Hannah Hoch,0.7509496,scribbles
-Shepard Fairey,0.7508583,scribbles
-Richard Burlet,0.7506659,scribbles
-Bill Brandt,0.6833408,black-white
-Herbert List,0.68455493,black-white
-Joseph Cornell,0.75023884,nudity
-Nathan Wirth,0.6436741,black-white
-John Kenn Mortensen,0.74758303,anime
-Andre De Dienes,0.5683014,digipa-high-impact
-Albert Robida,0.7485741,cartoon
-Shintaro Kago,0.7484431,anime
-Sidney Nolan,0.74809414,scribbles
-Patrice Murciano,0.61973965,fineart
-Brian Stelfreeze,0.7478351,scribbles
-Francisco De Goya,0.6954584,fineart
-William Morris,0.7478111,fineart
-Honoré Daumier,0.74767774,scribbles
-Hubert Robert,0.6863421,fineart
-Marianne von Werefkin,0.7475825,fineart
-Edvard Munch,0.74719715,scribbles
-Victor Brauner,0.74719006,scribbles
-George Inness,0.7470588,fineart
-Naoki Urasawa,0.7469665,anime
-Kilian Eng,0.7468486,scribbles
-Bordalo II,0.7467364,digipa-high-impact
-Katsuhiro Otomo,0.746364,anime
-Maximilien Luce,0.74609685,fineart
-Amy Earles,0.74603415,fineart
-Jeanloup Sieff,0.7196009,black-white
-William Zorach,0.74574494,scribbles
-Pascale Campion,0.74516207,fineart
-Dorothy Lathrop,0.74418795,fineart
-Sofonisba Anguissola,0.74418664,fineart
-Natalia Goncharova,0.74414873,scribbles
-August Sander,0.6644566,black-white
-Jasper Johns,0.74395454,scribbles
-Arthur Dove,0.74383533,scribbles
-Darwyn Cooke,0.7435789,scribbles
-Leonardo Da Vinci,0.6825216,fineart
-Fra Filippo Lippi,0.7433891,fineart
-Pierre-Auguste Renoir,0.742464,fineart
-Jeff Lemire,0.7422893,scribbles
-Al Williamson,0.742113,cartoon
-Childe Hassam,0.7418015,fineart
-Francisco Goya,0.69522625,fineart
-Alphonse Mucha,0.74171394,special
-Cleon Peterson,0.74163914,scribbles
-J.M.W. Turner,0.65582645,fineart
-Walter Crane,0.74146044,fineart
-Brassaï,0.6361966,digipa-high-impact
-Virgil Finlay,0.74133486,fineart
-Fernando Botero,0.7412504,nudity
-Ben Nicholson,0.7411573,scribbles
-Robert Rauschenberg,0.7410054,fineart
-David Wiesner,0.7406237,scribbles
-Bartolome Esteban Murillo,0.6933951,fineart
-Jean Arp,0.7403873,scribbles
-Andre Kertesz,0.7228358,black-white
-Simeon Solomon,0.66441345,fineart
-Hugh Ferriss,0.72443527,black-white
-Agnes Lawrence Pelton,0.73960555,scribbles
-Charles Camoin,0.7395686,scribbles
-Paul Strand,0.7080332,black-white
-Charles Gwathmey,0.7394747,scribbles
-Bartolomé Esteban Murillo,0.7011274,fineart
-Oskar Kokoschka,0.7392038,scribbles
-Bruno Munari,0.73918355,weird
-Willem de Kooning,0.73916197,scribbles
-Hans Memling,0.7387886,fineart
-Chris Mars,0.5861489,digipa-high-impact
-Hiroshi Yoshida,0.73787534,ukioe
-Hundertwasser,0.7377672,fineart
-David Bowie,0.73773724,weird
-Ettore Sottsass,0.7376095,digipa-high-impact
-Antanas Sutkus,0.7369492,black-white
-Leonora Carrington,0.73726475,scribbles
-Hieronymus Bosch,0.7369955,scribbles
-A. J. Casson,0.73666203,scribbles
-Chaim Soutine,0.73662066,scribbles
-Artur Bordalo,0.7364549,weird
-Thomas Allom,0.68792284,fineart
-Louis Comfort Tiffany,0.7363504,fineart
-Philippe Druillet,0.7363382,cartoon
-Jan Van Eyck,0.7360621,fineart
-Sandro Botticelli,0.7359395,fineart
-Hieronim Bosch,0.7359308,scribbles
-Everett Shinn,0.7355817,fineart
-Camille Corot,0.7355603,fineart
-Nick Sharratt,0.73470485,scribbles
-Fernand Léger,0.7079839,scribbles
-Robert S. Duncanson,0.7346282,fineart
-Hieronymous Bosch,0.73453265,scribbles
-Charles Addams,0.7344034,scribbles
-Studio Ghibli,0.73439026,anime
-Archibald Motley,0.7343683,scribbles
-Anton Fadeev,0.73433846,cartoon
-Uemura Shoen,0.7342118,ukioe
-Ando Fuchs,0.73406494,black-white
-Jessie Willcox Smith,0.73398125,fineart
-Alex Garant,0.7333658,scribbles
-Lawren Harris,0.73331416,scribbles
-Anne Truitt,0.73297834,scribbles
-Richard Lindner,0.7328564,scribbles
-Sailor Moon,0.73281246,anime
-Bridget Bate Tichenor,0.73274165,scribbles
-Ralph Steadman,0.7325864,scribbles
-Annibale Carracci,0.73251307,fineart
-Dürer,0.7324789,fineart
-Abigail Larson,0.7319012,cartoon
-Bill Traylor,0.73189163,scribbles
-Louis Rhead,0.7318623,fineart
-David Burliuk,0.731803,scribbles
-Camille Pissarro,0.73172396,fineart
-Catrin Welz-Stein,0.73117495,scribbles
-William Etty,0.6497544,nudity
-Pierre Bonnard,0.7310132,scribbles
-Benoit B. Mandelbrot,0.5033001,digipa-med-impact
-Théodore Géricault,0.692039,fineart
-Andy Goldsworthy,0.7307565,digipa-high-impact
-Alfred Sisley,0.7306032,fineart
-Charles-Francois Daubigny,0.73057353,fineart
-Karel Thole,0.7305395,cartoon
-Andre Derain,0.73050404,scribbles
-Larry Poons,0.73023695,fineart
-Beauford Delaney,0.72999024,scribbles
-Ruth Bernhard,0.72990334,black-white
-David Alfaro Siqueiros,0.7297947,scribbles
-Gaugin,0.729636,fineart
-Carl Larsson,0.7296195,cartoon
-Albrecht Dürer,0.72946966,fineart
-Henri De Toulouse Lautrec,0.7294263,cartoon
-Shotaro Ishinomori,0.7292093,anime
-Hope Gangloff,0.729082,scribbles
-Vivian Maier,0.72897506,digipa-high-impact
-Alex Andreev,0.6442978,digipa-high-impact
-Julie Blackmon,0.72862685,c
-Arthur Melville,0.7286146,fineart
-Henri Michaux,0.599607,fineart
-William Steig,0.7283096,scribbles
-Octavio Ocampo,0.72814554,scribbles
-Cy Twombly,0.72814107,scribbles
-Guy Denning,0.67375445,fineart
-Maxfield Parrish,0.7280283,fineart
-Randolph Caldecott,0.7279564,fineart
-Duccio,0.72795,fineart
-Ray Donley,0.5837457,fineart
-Hiroshi Sugimoto,0.6497892,digipa-high-impact
-Daniela Uhlig,0.4691466,special
-Go Nagai,0.72770613,anime
-Carlo Crivelli,0.72764605,fineart
-Helmut Newton,0.44433144,digipa-low-impact
-Josef Albers,0.7061394,scribbles
-Henry Moret,0.7274567,fineart
-André Masson,0.727404,scribbles
-Henri Fantin Latour,0.72732764,fineart
-Theo van Rysselberghe,0.7272843,fineart
-John Wayne Gacy,0.72686327,scribbles
-Carlos Schwabe,0.7267612,fineart
-Herbert Bayer,0.7094297,scribbles
-Domenichino,0.72667265,fineart
-Liam Wong,0.7262276,special
-George Caleb Bingham,0.7262154,digipa-high-impact
-Gigadō Ashiyuki,0.7261864,fineart
-Chaïm Soutine,0.72603923,scribbles
-Ary Scheffer,0.64913243,fineart
-Rockwell Kent,0.7257272,scribbles
-Jean-Paul Riopelle,0.72570604,fineart
-Ed Mell,0.6637067,cartoon
-Ismail Inceoglu,0.72561014,special
-Edgar Degas,0.72538006,fineart
-Giorgione,0.7252798,fineart
-Charles-François Daubigny,0.7252482,fineart
-Arthur Lismer,0.7251765,scribbles
-Aaron Siskind,0.4852289,digipa-med-impact
-Arkhip Kuindzhi,0.7249981,fineart
-Joseph Mallord William Turner,0.6834406,fineart
-Dante Gabriel Rossetti,0.7244541,fineart
-Ernst Haeckel,0.6660129,fineart
-Rebecca Guay,0.72439146,cartoon
-Anthony Gerace,0.636678,digipa-high-impact
-Martin Kippenberger,0.72418386,scribbles
-Diego Giacometti,0.72415763,scribbles
-Dmitry Kustanovich,0.7241322,cartoon
-Dora Carrington,0.7239633,scribbles
-Shusei Nagaoko,0.7238965,anime
-Odilon Redon,0.72381747,scribbles
-Shohei Otomo,0.7132803,nudity
-Barnett Newman,0.7236389,scribbles
-Jean Fouquet,0.7235963,fineart
-Gustav Klimt,0.72356784,nudity
-Francisco Josè de Goya,0.6589663,fineart
-Bonnard Pierre,0.72309464,nudity
-Brooke Shaden,0.61281693,digipa-high-impact
-Mao Hamaguchi,0.7228292,scribbles
-Frederick Edwin Church,0.64416,fineart
-Asher Brown Durand,0.72264796,fineart
-George Baselitz,0.7223453,scribbles
-Sam Bosma,0.7223237,fineart
-Asaf Hanuka,0.72222745,scribbles
-David Teniers the Younger,0.7221168,fineart
-Nicola Samori,0.68747556,nudity
-Claude Lorrain,0.7217102,fineart
-Hermenegildo Anglada Camarasa,0.7214374,nudity
-Pablo Picasso,0.72142905,scribbles
-Howard Chaykin,0.7213998,cartoon
-Ferdinand Hodler,0.7213758,nudity
-Farel Dalrymple,0.7213298,fineart
-Lyubov Popova,0.7213024,scribbles
-Albin Egger-Lienz,0.72120845,fineart
-Geertgen tot Sint Jans,0.72107565,fineart
-Kate Greenaway,0.72069687,fineart
-Louise Bourgeois,0.7206516,fineart
-Miriam Schapiro,0.72026414,fineart
-Pieter Claesz,0.7200939,fineart
-George B. Bridgman,0.5592567,fineart
-Piet Mondrian,0.71990657,scribbles
-Michelangelo Merisi Da Caravaggio,0.7094674,fineart
-Marie Spartali Stillman,0.71986604,fineart
-Gertrude Abercrombie,0.7196962,scribbles
-Louis Icart,0.7195913,fineart
-David Driskell,0.719564,scribbles
-Paula Modersohn-Becker,0.7193769,scribbles
-George Hurrell,0.57496595,digipa-high-impact
-Andrea Mantegna,0.7190254,fineart
-Silvestro Lega,0.71891177,fineart
-Junji Ito,0.7188978,anime
-Jacob Hashimoto,0.7186867,digipa-high-impact
-Benjamin West,0.6642946,fineart
-David Teniers the Elder,0.7181293,fineart
-Roberto Matta,0.71808386,fineart
-Chiho Aoshima,0.71801454,anime
-Amedeo Modigliani,0.71788836,scribbles
-Raja Ravi Varma,0.71788085,fineart
-Roberto Ferri,0.538221,nudity
-Winslow Homer,0.7176876,fineart
-Horace Vernet,0.65729,fineart
-Lucas Cranach the Elder,0.71738195,fineart
-Godfried Schalcken,0.625893,fineart
-Affandi,0.7170285,nudity
-Diane Arbus,0.655138,digipa-high-impact
-Joseph Ducreux,0.65247905,digipa-high-impact
-Berthe Morisot,0.7165984,fineart
-Hilma af Klint,0.71643853,scribbles
-Filippino Lippi,0.7163017,fineart
-Leonid Afremov,0.7163005,fineart
-Chris Ware,0.71628594,scribbles
-Marius Borgeaud,0.7162446,scribbles
-M.W. Kaluta,0.71612585,cartoon
-Govert Flinck,0.68975246,fineart
-Charles Demuth,0.71605396,scribbles
-Coles Phillips,0.7158309,scribbles
-Oskar Fischinger,0.6721027,digipa-high-impact
-David Teniers III,0.71569765,fineart
-Jean Delville,0.7156771,fineart
-Antonio Saura,0.7155949,scribbles
-Bridget Riley,0.7155669,fineart
-Gordon Parks,0.5759978,digipa-high-impact
-Anselm Kiefer,0.71514887,scribbles
-Remedios Varo,0.7150927,weird
-Franz Hegi,0.71495223,scribbles
-Kati Horna,0.71486115,black-white
-Arshile Gorky,0.71459055,scribbles
-David LaChapelle,0.7144903,scribbles
-Fritz von Dardel,0.71446383,scribbles
-Edward Ruscha,0.71438885,fineart
-Blanche Hoschedé Monet,0.7143073,fineart
-Alexandre Calame,0.5735474,fineart
-Sean Scully,0.714154,fineart
-Alexandre Benois,0.7141515,fineart
-Sally Mann,0.6534312,black-white
-Thomas Eakins,0.7141104,fineart
-Arnold Böcklin,0.71407956,fineart
-Alfonse Mucha,0.7139052,special
-Damien Hirst,0.7136273,scribbles
-Lee Krasner,0.71362555,scribbles
-Dorothea Lange,0.71361613,black-white
-Juan Gris,0.7132987,scribbles
-Bernardo Bellotto,0.70720065,fineart
-John Martin,0.5376847,fineart
-Harriet Backer,0.7131594,fineart
-Arnold Newman,0.5736342,digipa-high-impact
-Gjon Mili,0.46520913,digipa-low-impact
-Asger Jorn,0.7129575,scribbles
-Chesley Bonestell,0.6063316,fineart
-Agostino Carracci,0.7128167,fineart
-Peter Wileman,0.71271706,cartoon
-Chen Hongshou,0.71268153,ukioe
-Catherine Hyde,0.71266896,scribbles
-Andrea Pozzo,0.626546,fineart
-Kitty Lange Kielland,0.7125735,fineart
-Cornelis Saftleven,0.6684047,fineart
-Félix Vallotton,0.71237606,fineart
-Albrecht Durer,0.7122327,fineart
-Jackson Pollock,0.71222305,scribbles
-John Bratby,0.7122171,scribbles
-Beksinski,0.71218586,fineart
-James Thomas Watts,0.5959548,fineart
-Konstantin Korovin,0.71188873,fineart
-Gustave Caillebotte,0.71181154,fineart
-Dean Ellis,0.50233585,fineart
-Friedrich von Amerling,0.6420181,fineart
-Christopher Balaskas,0.67935324,special
-Alexander Rodchenko,0.67415404,scribbles
-Alfred Cheney Johnston,0.6647291,fineart
-Mikalojus Konstantinas Ciurlionis,0.710677,scribbles
-Jean-Antoine Watteau,0.71061164,fineart
-Paul Delvaux,0.7105914,scribbles
-Francesco del Cossa,0.7104901,nudity
-Isaac Cordal,0.71046066,weird
-Hikari Shimoda,0.7104546,weird
-François Boucher,0.67153126,fineart
-Akos Major,0.7103802,digipa-high-impact
-Bernard Buffet,0.7103491,cartoon
-Brandon Woelfel,0.6727086,digipa-high-impact
-Edouard Manet,0.7101296,fineart
-Auguste Herbin,0.6866145,scribbles
-Eugene Delacroix,0.70995826,fineart
-L. Birge Harrison,0.70989627,fineart
-Howard Pyle,0.70979863,fineart
-Diane Dillon,0.70968723,scribbles
-Hans Erni,0.7096618,scribbles
-Richard Diebenkorn,0.7096184,scribbles
-Thomas Gainsborough,0.6759419,fineart
-Maria Sibylla Merian,0.7093275,fineart
-François Joseph Heim,0.6175854,fineart
-E. H. Shepard,0.7091189,cartoon
-Hsiao-Ron Cheng,0.7090618,scribbles
-Canaletto,0.7090392,fineart
-John Atkinson Grimshaw,0.7087531,fineart
-Giovanni Battista Tiepolo,0.6754107,fineart
-Cornelis van Poelenburgh,0.69821274,fineart
-Raina Telgemeier,0.70846486,scribbles
-Francesco Hayez,0.6960006,fineart
-Gilbert Stuart,0.659772,fineart
-Konstantin Yuon,0.7081486,fineart
-Antonello da Messina,0.70806944,fineart
-Austin Osman Spare,0.7079903,fineart
-James Ensor,0.70781446,scribbles
-Claude Bonin-Pissarro,0.70739406,fineart
-Mikhail Vrubel,0.70738363,fineart
-Angelica Kauffman,0.6748828,fineart
-Viktor Vasnetsov,0.7072422,fineart
-Alphonse Osbert,0.70724136,fineart
-Tsutomu Nihei,0.7070495,anime
-Harvey Quaytman,0.63613266,fineart
-Jamie Hawkesworth,0.706914,digipa-high-impact
-Francesco Guardi,0.70682615,fineart
-Jean-Honoré Fragonard,0.6518248,fineart
-Brice Marden,0.70673287,digipa-high-impact
-Charles-Amédée-Philippe van Loo,0.6725916,fineart
-Mati Klarwein,0.7066092,n
-Gerard ter Borch,0.706589,fineart
-Dan Hillier,0.48966256,digipa-med-impact
-Federico Barocci,0.682664,fineart
-Henri Le Sidaner,0.70637953,fineart
-Olivier Bonhomme,0.7063748,scribbles
-Edward Weston,0.7061382,black-white
-Giovanni Paolo Cavagna,0.6840265,fineart
-Germaine Krull,0.6621777,black-white
-Hans Holbein the Younger,0.70590156,fineart
-François Bocion,0.6272365,fineart
-Georg Baselitz,0.7053314,scribbles
-Caravaggio,0.7050303,fineart
-Anne Rothenstein,0.70502245,scribbles
-Wadim Kashin,0.43714935,digipa-low-impact
-Heinrich Lefler,0.7048054,fineart
-Jacob van Ruisdael,0.7047918,fineart
-Bartholomeus van Bassen,0.6676872,fineart
-Jeffrey Smith art,0.56750107,fineart
-Anne Packard,0.7046703,weird
-Jean-François Millet,0.7045456,fineart
-Andrey Remnev,0.7041204,digipa-high-impact
-Fujiwara Takanobu,0.70410216,ukioe
-Elliott Erwitt,0.69950557,black-white
-Fern Coppedge,0.7036215,fineart
-Bartholomeus van der Helst,0.66411966,fineart
-Rembrandt Van Rijn,0.6979987,fineart
-Rene Magritte,0.703457,scribbles
-Aelbert Cuyp,0.7033657,fineart
-Gerda Wegener,0.70319015,scribbles
-Graham Sutherland,0.7031714,scribbles
-Gerrit Dou,0.7029986,fineart
-August Friedrich Schenck,0.6801586,fineart
-George Herriman,0.7028568,scribbles
-Stanisław Szukalski,0.6903354,fineart
-Slim Aarons,0.70222545,digipa-high-impact
-Ernst Thoms,0.70221686,fineart
-Louis Wain,0.702186,fineart
-Artemisia Gentileschi,0.70198226,fineart
-Eugène Delacroix,0.70155394,fineart
-Peter Bagge,0.70127463,scribbles
-Jeffrey Catherine Jones,0.7012148,cartoon
-Eugène Carrière,0.65272695,fineart
-Alexander Millar,0.7011144,scribbles
-Nobuyoshi Araki,0.70108867,fareast
-Tintoretto,0.6702795,fineart
-André Derain,0.7009005,scribbles
-Charles Maurice Detmold,0.70079994,fineart
-Francisco de Zurbarán,0.7007234,fineart
-Laurie Greasley,0.70072114,cartoon
-Lynda Benglis,0.7006948,digipa-high-impact
-Cecil Beaton,0.66362655,black-white
-Gustaf Tenggren,0.7006041,cartoon
-Abdur Rahman Chughtai,0.7004994,ukioe
-Constantin Brancusi,0.7004367,scribbles
-Mikhail Larionov,0.7004066,fineart
-Jan van Kessel the Elder,0.70040506,fineart
-Chantal Joffe,0.70036674,scribbles
-Charles-André van Loo,0.6830367,fineart
-Reginald Marsh,0.6301042,fineart
-Elsa Bleda,0.70005083,digipa-high-impact
-Peter Paul Rubens,0.65745676,fineart
-Eugène Boudin,0.70001304,fineart
-Charles Willson Peale,0.66907954,fineart
-Brian Mashburn,0.63395154,digipa-high-impact
-Barkley L. Hendricks,0.69986427,n
-Yoshiyuki Tomino,0.6998095,anime
-Guido Reni,0.6416875,fineart
-Lynd Ward,0.69958556,fineart
-John Constable,0.6907788,fineart
-František Kupka,0.6993329,fineart
-Pieter Bruegel The Elder,0.6992879,scribbles
-Benjamin Gerritsz Cuyp,0.6992173,fineart
-Nicolas Mignard,0.6988214,fineart
-Augustus Edwin Mulready,0.6482165,fineart
-Andrea del Sarto,0.698532,fineart
-Edward Steichen,0.69837445,black-white
-James Abbott McNeill Whistler,0.69836813,fineart
-Alphonse Legros,0.6983243,fineart
-Ivan Aivazovsky,0.64588225,fineart
-Giovanni Francesco Barbieri,0.6981316,fineart
-Grace Cossington Smith,0.69811064,fineart
-Bert Stern,0.53411555,scribbles
-Mary Cassatt,0.6980135,fineart
-Jules Bastien-Lepage,0.69796044,fineart
-Max Ernst,0.69777006,fineart
-Kentaro Miura,0.697743,anime
-Georges Rouault,0.69758564,scribbles
-Josephine Wall,0.6973667,fineart
-Anne-Louis Girodet,0.58104825,nudity
-Bert Hardy,0.6972966,black-white
-Adriaen van de Velde,0.69716156,fineart
-Andreas Achenbach,0.61108655,fineart
-Hayv Kahraman,0.69705284,fineart
-Beatrix Potter,0.6969851,fineart
-Elmer Bischoff,0.6968948,fineart
-Cornelis de Heem,0.6968436,fineart
-Inio Asano,0.6965007,anime
-Alfred Henry Maurer,0.6964837,fineart
-Gottfried Helnwein,0.6962953,digipa-high-impact
-Paul Barson,0.54196984,digipa-high-impact
-Roger de La Fresnaye,0.69620967,fineart
-Abraham Mignon,0.60605425,fineart
-Albert Bloch,0.69573116,nudity
-Charles Dana Gibson,0.67155975,fineart
-Alexandre-Évariste Fragonard,0.6507174,fineart
-Ernst Fuchs,0.6953538,nudity
-Alfredo Jaar,0.6952965,digipa-high-impact
-Judy Chicago,0.6952246,weird
-Frans van Mieris the Younger,0.6951849,fineart
-Aertgen van Leyden,0.6951305,fineart
-Emily Carr,0.69512105,fineart
-Frances MacDonald,0.6950408,scribbles
-Hannah Höch,0.69495845,scribbles
-Gillis Rombouts,0.58770025,fineart
-Käthe Kollwitz,0.6947756,fineart
-Barbara Stauffacher Solomon,0.6920825,fineart
-Georges Lacombe,0.6944455,fineart
-Gwen John,0.6944161,fineart
-Terada Katsuya,0.6944026,cartoon
-James Gillray,0.6871335,fineart
-Robert Crumb,0.69420326,fineart
-Bruce Pennington,0.6545669,fineart
-David Firth,0.69400465,scribbles
-Arthur Boyd,0.69399726,fineart
-Antonin Artaud,0.67321455,fineart
-Giuseppe Arcimboldo,0.6937329,fineart
-Jim Mahfood,0.6936606,cartoon
-Ossip Zadkine,0.6494374,scribbles
-Atelier Olschinsky,0.69349927,fineart
-Carl Frederik von Breda,0.57274634,fineart
-Ken Sugimori,0.6932626,anime
-Chris Friel,0.5399168,fineart
-Andrew Macara,0.69307995,fineart
-Alexander Jansson,0.69298327,scribbles
-Anne Brigman,0.6865817,black-white
-George Ault,0.66756654,fineart
-Arkhyp Kuindzhi,0.6928072,digipa-high-impact
-Emiliano Ponzi,0.69278395,scribbles
-William Holman Hunt,0.6927663,fineart
-Tamara Lempicka,0.6386007,scribbles
-Mark Ryden,0.69259655,fineart
-Giovanni Paolo Pannini,0.6802902,fineart
-Carl Barks,0.6923666,cartoon
-Fritz Bultman,0.6318746,fineart
-Salomon van Ruysdael,0.690313,fineart
-Carrie Mae Weems,0.6645416,n
-Agostino Arrivabene,0.61166185,fineart
-Gustave Boulanger,0.655797,fineart
-Henry Justice Ford,0.51214355,fareast
-Bernardo Strozzi,0.63510317,fineart
-André Lhote,0.68718815,scribbles
-Paul Corfield,0.6915611,scribbles
-Gifford Beal,0.6914777,fineart
-Hirohiko Araki,0.6914078,anime
-Emil Carlsen,0.691326,fineart
-Frans van Mieris the Elder,0.6912799,fineart
-Simon Stalenhag,0.6912775,special
-Henry van de Velde,0.64838886,fineart
-Eleanor Fortescue-Brickdale,0.6909729,fineart
-Thomas W Schaller,0.69093937,special
-NHK Animation,0.6907677,cartoon
-Euan Uglow,0.69060403,scribbles
-Hendrick Goltzius,0.69058937,fineart
-William Blake,0.69038224,fineart
-Vito Acconci,0.58409876,digipa-high-impact
-Billy Childish,0.6902057,scribbles
-Ben Quilty,0.6875855,fineart
-Mark Briscoe,0.69010437,fineart
-Adriaen van de Venne,0.6899867,fineart
-Alasdair McLellan,0.6898454,digipa-high-impact
-Ed Paschke,0.68974686,scribbles
-Guy Rose,0.68960273,fineart
-Barbara Hepworth,0.68958247,fineart
-Edward Henry Potthast,0.6895703,fineart
-Francis Bacon,0.6895397,scribbles
-Pawel Kuczynski,0.6894536,fineart
-Bjarke Ingels,0.68933153,digipa-high-impact
-Henry Ossawa Tanner,0.68932164,fineart
-Alessandro Allori,0.6892961,fineart
-Abraham van Calraet,0.63841593,fineart
-Egon Schiele,0.6891415,scribbles
-Tim Doyle,0.5474768,digipa-high-impact
-Grandma Moses,0.6890782,fineart
-John Frederick Kensett,0.61981744,fineart
-Giacomo Balla,0.68893707,fineart
-Jamie Baldridge,0.6546651,digipa-high-impact
-Max Beckmann,0.6884731,scribbles
-Cornelis van Haarlem,0.6677613,fineart
-Edward Hopper,0.6884258,special
-Barkley Hendricks,0.6883637,n
-Patrick Dougherty,0.688321,digipa-high-impact
-Karol Bak,0.6367705,fineart
-Pierre Puvis de Chavannes,0.6880703,fineart
-Antoni Tàpies,0.685689,fineart
-Alexander Nasmyth,0.57695735,fineart
-Laurent Grasso,0.5793272,fineart
-Camille Walala,0.6076875,digipa-high-impact
-Fairfield Porter,0.68790644,fineart
-Alex Colville,0.68787855,fineart
-Herb Ritts,0.51471305,scribbles
-Gerhard Munthe,0.687658,fineart
-Susan Seddon Boulet,0.68762136,scribbles
-Liu Ye,0.68760437,fineart
-Robert Antoine Pinchon,0.68744636,fineart
-Fujiwara Nobuzane,0.6873439,fineart
-Frederick Carl Frieseke,0.6873361,fineart
-Aert van der Neer,0.6159286,fineart
-Allen Jones,0.6869935,scribbles
-Anja Millen,0.6064488,digipa-high-impact
-Esaias van de Velde,0.68673944,fineart
-Gyoshū Hayami,0.68665624,anime
-William Hogarth,0.6720842,fineart
-Frederic Church,0.6865637,fineart
-Cyril Rolando,0.68644965,cartoon
-Frederic Edwin Church,0.6863009,fineart
-Thomas Rowlandson,0.66726154,fineart
-Joachim Brohm,0.68601763,digipa-high-impact
-Cristofano Allori,0.6858083,fineart
-Adrianus Eversen,0.58259964,fineart
-Richard Dadd,0.68546164,fineart
-Ambrosius Bosschaert II,0.6854217,fineart
-Paolo Veronese,0.68422073,fineart
-Abraham van den Tempel,0.66463804,fineart
-Duncan Grant,0.6852565,scribbles
-Hendrick Cornelisz. van Vliet,0.6851691,fineart
-Geof Darrow,0.6851174,scribbles
-Émile Bernard,0.6850957,fineart
-Brian Bolland,0.68496394,scribbles
-James Gilleard,0.6849431,cartoon
-Anton Raphael Mengs,0.6689196,fineart
-Augustus Jansson,0.6845705,digipa-high-impact
-Hendrik Goltzius,0.6843367,fineart
-Domenico Quaglio the Younger,0.65769434,fineart
-Cicely Mary Barker,0.6841806,fineart
-William Eggleston,0.6840795,digipa-high-impact
-David Choe,0.6840449,scribbles
-Adam Elsheimer,0.6716068,fineart
-Heinrich Danioth,0.5390186,fineart
-Franz Stuck,0.6836468,fineart
-Bernie Wrightson,0.64101505,fineart
-Dorina Costras,0.6835419,fineart
-El Greco,0.68343943,fineart
-Gatōken Shunshi,0.6833314,anime
-Giovanni Bellini,0.67622876,fineart
-Aron Wiesenfeld,0.68331146,nudity
-Boris Kustodiev,0.68329334,fineart
-Alec Soth,0.5597321,digipa-high-impact
-Artus Scheiner,0.6313348,fineart
-Kelly Vivanco,0.6830933,scribbles
-Shaun Tan,0.6830649,fineart
-Anthony van Dyck,0.6577681,fineart
-Neil Welliver,0.68297863,nudity
-Robert McCall,0.68294585,fineart
-Sandra Chevrier,0.68284667,scribbles
-Yinka Shonibare,0.68256056,n
-Arthur Tress,0.6301861,digipa-high-impact
-Richard McGuire,0.6820089,scribbles
-Anni Albers,0.65708244,digipa-high-impact
-Aleksey Savrasov,0.65207493,fineart
-Wayne Barlowe,0.6537874,fineart
-Giorgio de Chirico,0.6815907,fineart
-Ernest Procter,0.6815795,fineart
-Adriaen Brouwer,0.6815058,fineart
-Ilya Glazunov,0.6813533,fineart
-Alison Bechdel,0.68096143,scribbles
-Carl Holsoe,0.68082225,fineart
-Alfred Edward Chalon,0.6464571,fineart
-Gerard David,0.68058,fineart
-Basil Blackshaw,0.6805679,fineart
-Gerrit Adriaenszoon Berckheyde,0.67340267,fineart
-George Hendrik Breitner,0.6804209,fineart
-Abraham Bloemaert,0.68036544,fineart
-Ferdinand Van Kessel,0.67742276,fineart
-Hugo Simberg,0.68031186,fineart
-Gaston Bussière,0.665221,fineart
-Shawn Coss,0.42407864,digipa-low-impact
-Hanabusa Itchō,0.68023074,ukioe
-Magnus Enckell,0.6801553,fineart
-Gary Larson,0.6801336,scribbles
-George Manson,0.68013126,digipa-high-impact
-Hayao Miyazaki,0.6800754,anime
-Carl Spitzweg,0.66581815,fineart
-Ambrosius Holbein,0.6798341,fineart
-Domenico Pozzi,0.6434162,fineart
-Dorothea Tanning,0.6797955,fineart
-Jeannette Guichard-Bunel,0.5251578,digipa-high-impact
-Victor Moscoso,0.62962687,fineart
-Francis Picabia,0.6795391,scribbles
-Charles W. Bartlett,0.67947805,fineart
-David A Hardy,0.5554935,fineart
-C. R. W. Nevinson,0.67946506,fineart
-Man Ray,0.6507145,scribbles
-Albert Bierstadt,0.67935765,fineart
-Charles Le Brun,0.6758479,fineart
-Lovis Corinth,0.67913896,fineart
-Herbert Abrams,0.5507507,digipa-high-impact
-Giorgio Morandi,0.6789025,fineart
-Agnolo Bronzino,0.6787985,fineart
-Abraham Pether,0.66922426,fineart
-John Bauer,0.6786695,fineart
-Arthur Stanley Wilkinson,0.67860866,fineart
-Arthur Wardle,0.5510789,fineart
-George Romney,0.62868094,fineart
-Laurie Lipton,0.5201844,fineart
-Mickalene Thomas,0.45433685,digipa-low-impact
-Alice Rahon,0.6777824,scribbles
-Gustave Van de Woestijne,0.6777346,scribbles
-Laurel Burch,0.67766285,fineart
-Hendrik Gerritsz Pot,0.67750573,fineart
-John William Waterhouse,0.677472,fineart
-Conor Harrington,0.5967809,fineart
-Gabriel Ba,0.6773366,cartoon
-Franz Xaver Winterhalter,0.62229514,fineart
-George Cruikshank,0.6473593,fineart
-Hyacinthe Rigaud,0.67717785,fineart
-Cornelis Claesz van Wieringen,0.6770269,fineart
-Adriaen van Outrecht,0.67682564,fineart
-Yaacov Agam,0.6767926,fineart
-Franz von Lenbach,0.61948,fineart
-Clyfford Still,0.67667866,fineart
-Alexander Roslin,0.66719526,fineart
-Barry Windsor Smith,0.6765375,cartoon
-Takeshi Obata,0.67643225,anime
-John Harris,0.47712502,fineart
-Bruce Davidson,0.6763525,digipa-high-impact
-Hendrik Willem Mesdag,0.6762745,fineart
-Makoto Shinkai,0.67610705,anime
-Andreas Gursky,0.67610145,digipa-high-impact
-Mike Winkelmann (Beeple),0.6510196,digipa-high-impact
-Gustave Moreau,0.67607844,fineart
-Frank Weston Benson,0.6760142,fineart
-Eduardo Kingman,0.6759026,fineart
-Benjamin Williams Leader,0.5611925,fineart
-Hervé Guibert,0.55973417,black-white
-Cornelis Dusart,0.6753622,fineart
-Amédée Guillemin,0.6752696,fineart
-Alessio Albi,0.6752633,digipa-high-impact
-Matthias Grünewald,0.6751779,fineart
-Fujishima Takeji,0.6751577,anime
-Georges Braque,0.67514753,scribbles
-John Salminen,0.67498183,fineart
-Atey Ghailan,0.674873,scribbles
-Giovanni Antonio Galli,0.657484,fineart
-Julie Mehretu,0.6748382,fineart
-Jean Auguste Dominique Ingres,0.6746286,fineart
-Francesco Albani,0.6621554,fineart
-Anato Finnstark,0.6744919,digipa-high-impact
-Giovanni Bernardino Mazzolini,0.64416045,fineart
-Antoine Le Nain,0.6233709,fineart
-Ford Madox Brown,0.6743224,fineart
-Gerhard Richter,0.67426133,fineart
-theCHAMBA,0.6742506,cartoon
-Edward Julius Detmold,0.67421955,fineart
-George Stubbs,0.6209227,fineart
-George Tooker,0.6740602,scribbles
-Faith Ringgold,0.6739976,scribbles
-Giambattista Pittoni,0.5792371,fineart
-George Bellows,0.6737008,fineart
-Aldus Manutius,0.67366326,fineart
-Ambrosius Bosschaert,0.67364097,digipa-high-impact
-Michael Parkes,0.6133628,fineart
-Hans Bellmer,0.6735973,nudity
-Sir James Guthrie,0.67359626,fineart
-Charles Spencelayh,0.67356884,fineart
-Ivan Shishkin,0.6734136,fineart
-Hans Holbein the Elder,0.6733856,fineart
-Filip Hodas,0.60053295,digipa-high-impact
-Herman Saftleven,0.6732188,digipa-high-impact
-Dirck de Quade van Ravesteyn,0.67309594,fineart
-Joe Fenton,0.6730916,scribbles
-Arnold Bocklin,0.6730706,fineart
-Baiōken Eishun,0.6730663,anime
-Giovanni Giacometti,0.6730505,fineart
-Giovanni Battista Gaulli,0.65036476,fineart
-William Stout,0.672887,fineart
-Gavin Hamilton,0.5982757,fineart
-John Stezaker,0.6726847,black-white
-Frederick McCubbin,0.67263377,fineart
-Christoph Ludwig Agricola,0.62750757,fineart
-Alice Neel,0.67255914,scribbles
-Giovanni Battista Venanzi,0.61996603,fineart
-Miho Hirano,0.6724092,anime
-Tom Thomson,0.6723876,fineart
-Alfred Munnings,0.6723851,fineart
-David Wilkie,0.6722781,fineart
-Adriaen van Ostade,0.67220736,fineart
-Alfred Eisenstaedt,0.67213774,black-white
-Leon Kossoff,0.67208946,fineart
-Georges de La Tour,0.6421979,fineart
-Chuck Close,0.6719756,digipa-high-impact
-Herbert MacNair,0.6719506,scribbles
-Edward Atkinson Hornel,0.6719265,fineart
-Becky Cloonan,0.67192084,cartoon
-Gian Lorenzo Bernini,0.58210254,fineart
-Hein Gorny,0.4982776,digipa-med-impact
-Joe Webb,0.6714884,fineart
-Cornelis Pietersz Bega,0.64423996,fineart
-Christian Krohg,0.6713641,fineart
-Cornelia Parker,0.6712246,fineart
-Anna Mary Robertson Moses,0.6709144,fineart
-Quentin Tarantino,0.6708354,digipa-high-impact
-Frederic Remington,0.67074275,fineart
-Barent Fabritius,0.6707407,fineart
-Oleg Oprisco,0.6707388,digipa-high-impact
-Hendrick van Streeck,0.670666,fineart
-Bakemono Zukushi,0.67051035,anime
-Lucy Madox Brown,0.67032814,fineart
-Paul Wonner,0.6700563,scribbles
-Guido Borelli Da Caluso,0.66966087,digipa-high-impact
-Emil Alzamora,0.5844039,nudity
-Heinrich Brocksieper,0.64469147,fineart
-Dan Smith,0.669563,digipa-high-impact
-Lois van Baarle,0.6695091,scribbles
-Arthur Garfield Dove,0.6694996,scribbles
-Matthias Jung,0.66936135,digipa-high-impact
-José Clemente Orozco,0.6693544,scribbles
-Don Bluth,0.6693046,cartoon
-Akseli Gallen-Kallela,0.66927314,fineart
-Alex Howitt,0.52858865,digipa-high-impact
-Giovanni Bernardino Asoleni,0.6635405,fineart
-Frederick Goodall,0.6690712,fineart
-Francesco Bartolozzi,0.63431,fineart
-Edmund Leighton,0.6689639,fineart
-Abraham Willaerts,0.5966594,fineart
-François Louis Thomas Francia,0.6207474,fineart
-Carel Fabritius,0.6688478,fineart
-Flora Macdonald Reid,0.6687404,fineart
-Bartholomeus Breenbergh,0.6163084,fineart
-Bernardino Mei,0.6486895,fineart
-Carel Weight,0.6684968,fineart
-Aristide Maillol,0.66843045,scribbles
-Chris Leib,0.60567486,fineart
-Giovanni Battista Piazzetta,0.65012705,fineart
-Daniel Maclise,0.6678073,fineart
-Giovanni Bernardino Azzolini,0.65774256,fineart
-Aaron Horkey,0.6676864,fineart
-Otto Dix,0.667294,scribbles
-Ferdinand Bol,0.6414797,fineart
-Adriaen Coorte,0.6670663,fineart
-William Gropper,0.6669881,scribbles
-Gerard de Lairesse,0.6639489,fineart
-Mab Graves,0.6668356,scribbles
-Fernando Amorsolo,0.66683346,fineart
-Pixar Concept Artists,0.6667752,cartoon
-Alfred Augustus Glendening,0.64009607,fineart
-Diego Velázquez,0.6666799,fineart
-Jerry Pinkney,0.6665478,fineart
-Antoine Wiertz,0.6143825,fineart
-Alberto Burri,0.6618252,scribbles
-Max Weber,0.6664029,fineart
-Hans Baluschek,0.66636246,fineart
-Annie Swynnerton,0.6663346,fineart
-Albert Dubois-Pillet,0.57526016,fineart
-Dora Maar,0.62862253,digipa-high-impact
-Kay Sage,0.5614823,fineart
-David A. Hardy,0.51376164,fineart
-Alberto Biasi,0.42917693,digipa-low-impact
-Fra Bartolomeo,0.6661105,fineart
-Hendrick van Balen,0.65754294,fineart
-Edwin Austin Abbey,0.66596496,fineart
-George Frederic Watts,0.66595024,fineart
-Alexei Kondratyevich Savrasov,0.6470352,fineart
-Anna Ancher,0.66581213,fineart
-Irma Stern,0.66580737,fineart
-Frédéric Bazille,0.6657115,fineart
-Awataguchi Takamitsu,0.6656272,anime
-Edward Sorel,0.6655388,fineart
-Edward Lear,0.6655078,fineart
-Gabriel Metsu,0.6654555,fineart
-Giovanni Battista Innocenzo Colombo,0.6653655,fineart
-Scott Naismith,0.6650656,fineart
-John Perceval,0.6650283,fineart
-Girolamo Muziano,0.64234406,fineart
-Cornelis de Man,0.66494393,fineart
-Cornelis Bisschop,0.64119905,digipa-high-impact
-Hans Leu the Elder,0.64770013,fineart
-Michael Hutter,0.62479556,fineart
-Cornelia MacIntyre Foley,0.6510235,fineart
-Todd McFarlane,0.6647763,cartoon
-John James Audubon,0.6279882,digipa-high-impact
-William Henry Hunt,0.57340264,fineart
-John Anster Fitzgerald,0.6644317,fineart
-Tomer Hanuka,0.6643152,cartoon
-Alex Prager,0.6641814,fineart
-Heinrich Kley,0.6641148,fineart
-Anne Redpath,0.66407835,scribbles
-Marianne North,0.6640104,fineart
-Daniel Merriam,0.6639365,fineart
-Bill Carman,0.66390574,fineart
-Méret Oppenheim,0.66387725,digipa-high-impact
-Erich Heckel,0.66384083,fineart
-Iryna Yermolova,0.663623,fineart
-Antoine Ignace Melling,0.61502695,fineart
-Akira Toriyama,0.6635002,anime
-Gregory Crewdson,0.59810174,digipa-high-impact
-Helene Schjerfbeck,0.66333634,fineart
-Antonio Mancini,0.6631618,fineart
-Zanele Muholi,0.58554715,n
-Balthasar van der Ast,0.66294503,fineart
-Toei Animations,0.6629127,anime
-Arthur Quartley,0.6628106,fineart
-Diego Rivera,0.6625808,fineart
-Hendrik van Steenwijk II,0.6623777,fineart
-James Tissot,0.6623415,fineart
-Kehinde Wiley,0.66218376,n
-Chiharu Shiota,0.6621249,digipa-high-impact
-George Grosz,0.6620224,fineart
-Peter De Seve,0.6616659,cartoon
-Ryan Hewett,0.6615638,fineart
-Hasegawa Tōhaku,0.66146004,anime
-Apollinary Vasnetsov,0.6613177,fineart
-Francis Cadell,0.66119456,fineart
-Henri Harpignies,0.6611012,fineart
-Henry Macbeth-Raeburn,0.6213787,fineart
-Christoffel van den Berghe,0.6609149,fineart
-Leiji Matsumoto,0.66089404,anime
-Adriaen van der Werff,0.638286,fineart
-Ramon Casas,0.6606529,fineart
-Arthur Hacker,0.66062653,fineart
-Edward Willis Redfield,0.66058433,fineart
-Carl Gustav Carus,0.65355223,fineart
-Francesca Woodman,0.60435605,digipa-high-impact
-Hans Makart,0.5881955,fineart
-Carne Griffiths,0.660091,weird
-Will Barnet,0.65995145,scribbles
-Fitz Henry Lane,0.659841,fineart
-Masaaki Sasamoto,0.6597158,anime
-Salvador Dali,0.6290813,scribbles
-Walt Kelly,0.6596993,digipa-high-impact
-Charlotte Nasmyth,0.56481636,fineart
-Ferdinand Knab,0.6596528,fineart
-Steve Lieber,0.6596117,scribbles
-Zhang Kechun,0.6595939,fareast
-Olivier Valsecchi,0.5324838,digipa-high-impact
-Joel Meyerowitz,0.65937585,digipa-high-impact
-Arthur Streeton,0.6592294,fineart
-Henriett Seth F.,0.6592273,fineart
-Genndy Tartakovsky,0.6591695,scribbles
-Otto Marseus van Schrieck,0.65890455,fineart
-Hanna-Barbera,0.6588123,cartoon
-Mary Anning,0.6588001,fineart
-Pamela Colman Smith,0.6587648,fineart
-Anton Mauve,0.6586873,fineart
-Hendrick Avercamp,0.65866685,fineart
-Max Pechstein,0.65860206,scribbles
-Franciszek Żmurko,0.56855476,fineart
-Felice Casorati,0.6584761,fineart
-Louis Janmot,0.65298057,fineart
-Thomas Cole,0.5408042,fineart
-Peter Mohrbacher,0.58273685,fineart
-Arnold Franz Brasz,0.65834284,nudity
-Christian Rohlfs,0.6582814,fineart
-Basil Gogos,0.658105,fineart
-Fitz Hugh Lane,0.657923,fineart
-Liubov Sergeevna Popova,0.62325525,fineart
-Elizabeth MacNicol,0.65773135,fineart
-Zinaida Serebriakova,0.6577016,fineart
-Ernest Lawson,0.6575238,fineart
-Bruno Catalano,0.6574354,fineart
-Albert Namatjira,0.6573372,fineart
-Fritz von Uhde,0.6572697,fineart
-Edwin Henry Landseer,0.62363374,fineart
-Naoto Hattori,0.621745,fareast
-Reylia Slaby,0.65709853,fineart
-Arthur Burdett Frost,0.6147318,fineart
-Frank Miller,0.65707314,digipa-high-impact
-Algernon Talmage,0.65702903,fineart
-Itō Jakuchū,0.6570199,digipa-high-impact
-Billie Waters,0.65684533,digipa-high-impact
-Ingrid Baars,0.58558,digipa-high-impact
-Pieter Jansz Saenredam,0.6566058,fineart
-Egbert van Heemskerck,0.6125889,fineart
-John French Sloan,0.6362145,fineart
-Craola,0.65639997,scribbles
-Benjamin Marra,0.61809736,nudity
-Anthony Thieme,0.65609205,fineart
-Satoshi Kon,0.65606606,anime
-Masamune Shirow,0.65592873,anime
-Alfred Stevens,0.6557321,fineart
-Hariton Pushwagner,0.6556745,anime
-Carlo Carrà,0.6556279,fineart
-Stuart Davis,0.6050534,digipa-high-impact
-David Shrigley,0.6553904,digipa-high-impact
-Albrecht Anker,0.65531695,fineart
-Anton Semenov,0.6552501,digipa-high-impact
-Fabio Hurtado,0.5955889,fineart
-Donald Judd,0.6552257,fineart
-Francisco de Burgos Mantilla,0.65516514,fineart
-Barthel Bruyn the Younger,0.6551433,fineart
-Abram Arkhipov,0.6550962,fineart
-Paulus Potter,0.65498203,fineart
-Edward Lamson Henry,0.6549521,fineart
-Audrey Kawasaki,0.654843,fineart
-George Catlin,0.6547183,fineart
-Adélaïde Labille-Guiard,0.6066263,fineart
-Sandy Skoglund,0.6546999,digipa-high-impact
-Hans Baldung,0.654431,fineart
-Ethan Van Sciver,0.65442884,cartoon
-Frans Hals,0.6542338,fineart
-Caspar David Friedrich,0.6542175,fineart
-Charles Conder,0.65420866,fineart
-Betty Churcher,0.65387225,fineart
-Claes Corneliszoon Moeyaert,0.65386075,fineart
-David Bomberg,0.6537477,fineart
-Abraham Bosschaert,0.6535562,fineart
-Giuseppe de Nittis,0.65354455,fineart
-John La Farge,0.65342575,fineart
-Frits Thaulow,0.65341854,fineart
-John Duncan,0.6532379,fineart
-Floris van Dyck,0.64900756,fineart
-Anton Pieck,0.65310377,fineart
-Roger Dean,0.6529647,nudity
-Maximilian Pirner,0.65280807,fineart
-Dorothy Johnstone,0.65267503,fineart
-Govert Dircksz Camphuysen,0.65258145,fineart
-Ryohei Hase,0.6168618,fineart
-Hans von Aachen,0.62437224,fineart
-Gustaf Munch-Petersen,0.6522485,fineart
-Earnst Haeckel,0.6344333,fineart
-Giovanni Battista Bracelli,0.62635326,fineart
-Hendrick Goudt,0.6521433,fineart
-Aneurin Jones,0.65191466,fineart
-Bryan Hitch,0.6518333,cartoon
-Coby Whitmore,0.6515695,fineart
-Barthélemy d'Eyck,0.65156406,fineart
-Quint Buchholz,0.65151155,fineart
-Adriaen Hanneman,0.6514815,fineart
-Tom Roberts,0.5855832,fineart
-Fernand Khnopff,0.6512954,nudity
-Charles Vess,0.6512271,cartoon
-Carlo Galli Bibiena,0.6511681,nudity
-Alexander Milne Calder,0.6081027,fineart
-Josan Gonzalez,0.6193469,cartoon
-Barthel Bruyn the Elder,0.6509954,fineart
-Jon Whitcomb,0.6046063,fineart
-Arcimboldo,0.6509897,fineart
-Hendrik van Steenwijk I,0.65086293,fineart
-Albert Joseph Pénot,0.65085316,fineart
-Edward Wadsworth,0.6308917,scribbles
-Andrew Wyeth,0.6507103,fineart
-Correggio,0.650689,fineart
-Frances Currey,0.65068,fineart
-Henryk Siemiradzki,0.56721973,fineart
-Worthington Whittredge,0.6504713,fineart
-Federico Zandomeneghi,0.65033823,fineart
-Isaac Levitan,0.6503356,fineart
-Russ Mills,0.65012795,fineart
-Edith Lawrence,0.65010095,fineart
-Gil Elvgren,0.5614284,digipa-high-impact
-Chris Foss,0.56495357,fineart
-Francesco Zuccarelli,0.612805,fineart
-Hendrick Bloemaert,0.64962655,fineart
-Egon von Vietinghoff,0.57180583,fineart
-Pixar,0.6495793,cartoon
-Daniel Clowes,0.6495775,fineart
-Friedrich Ritter von Friedländer-Malheim,0.6493772,fineart
-Rebecca Sugar,0.6492679,scribbles
-Chen Daofu,0.6492026,fineart
-Dustin Nguyen,0.64909416,cartoon
-Raymond Duchamp-Villon,0.6489605,nudity
-Daniel Garber,0.6489332,fineart
-Antonio Canova,0.58764786,fineart
-Algernon Blackwood,0.59256804,fineart
-Betye Saar,0.64877665,fineart
-William S. Burroughs,0.5505619,fineart
-Rodney Matthews,0.64844495,fineart
-Michelangelo Buonarroti,0.6484401,fineart
-Posuka Demizu,0.64843124,anime
-Joao Ruas,0.6484134,fineart
-Andy Fairhurst,0.6480388,special
-"Andries Stock, Dutch Baroque painter",0.6479797,fineart
-Antonio de la Gandara,0.6479292,fineart
-Bruce Timm,0.6477877,scribbles
-Harvey Kurtzman,0.64772683,cartoon
-Eiichiro Oda,0.64772165,anime
-Edwin Landseer,0.6166703,fineart
-Carl Heinrich Bloch,0.64755356,fineart
-Adriaen Isenbrant,0.6475428,fineart
-Santiago Caruso,0.6473954,fineart
-Alfred Guillou,0.6472603,fineart
-Clara Peeters,0.64725095,fineart
-Kim Jung Gi,0.6472225,cartoon
-Milo Manara,0.6471776,cartoon
-Phil Noto,0.6470769,anime
-Kaws,0.6470336,cartoon
-Desmond Morris,0.5951916,fineart
-Gediminas Pranckevicius,0.6467787,fineart
-Jack Kirby,0.6467424,cartoon
-Claes Jansz. Visscher,0.6466888,fineart
-Augustin Meinrad Bächtiger,0.6465789,fineart
-John Lavery,0.64643383,fineart
-Anne Bachelier,0.6464065,fineart
-Giuseppe Bernardino Bison,0.64633006,fineart
-E. T. A. Hoffmann,0.5887251,fineart
-Ambrosius Benson,0.6457839,fineart
-Cornelis Verbeeck,0.645782,fineart
-H. R. Giger,0.6456823,weird
-Adolph Menzel,0.6455246,fineart
-Aliza Razell,0.5863178,digipa-high-impact
-Gerard Seghers,0.6205679,fineart
-David Aja,0.62812066,scribbles
-Gustave Courbet,0.64476407,fineart
-Alexandre Cabanel,0.63849115,fineart
-Albert Marquet,0.64471006,fineart
-Harold Harvey,0.64464307,fineart
-William Wegman,0.6446265,scribbles
-Harold Gilman,0.6445966,fineart
-Jeremy Geddes,0.57839495,digipa-high-impact
-Abraham van Beijeren,0.6356113,fineart
-Eugène Isabey,0.6160607,fineart
-Jorge Jacinto,0.58618563,fineart
-Frederic Leighton,0.64383554,fineart
-Dave McKean,0.6438012,cartoon
-Hiromu Arakawa,0.64371413,anime
-Aaron Douglas,0.6437089,fineart
-Adolf Dietrich,0.590169,fineart
-Frederik de Moucheron,0.6435952,fineart
-Siya Oum,0.6435919,cartoon
-Alberto Morrocco,0.64352196,fineart
-Robert Vonnoh,0.6433115,fineart
-Tom Bagshaw,0.5322264,fineart
-Guerrilla Girls,0.64309967,digipa-high-impact
-Johann Wolfgang von Goethe,0.6429888,fineart
-Charles Le Roux,0.6426594,fineart
-Auguste Toulmouche,0.64261353,fineart
-Cindy Sherman,0.58666563,digipa-high-impact
-Federico Zuccari,0.6425021,fineart
-Mike Mignola,0.642346,cartoon
-Cecily Brown,0.6421981,fineart
-Brian K. Vaughan,0.64147836,cartoon
-RETNA (Marquis Lewis),0.47963,n
-Klaus Janson,0.64129144,cartoon
-Alessandro Galli Bibiena,0.6412889,fineart
-Jeremy Lipking,0.64123213,fineart
-Stephen Shore,0.64108944,digipa-high-impact
-Heinz Edelmann,0.51325977,digipa-med-impact
-Joaquín Sorolla,0.6409732,fineart
-Bella Kotak,0.6409608,digipa-high-impact
-Cornelis Engebrechtsz,0.64091057,fineart
-Bruce Munro,0.64084166,digipa-high-impact
-Marjane Satrapi,0.64076495,fineart
-Jeremy Mann,0.557744,digipa-high-impact
-Heinrich Maria Davringhausen,0.6403986,fineart
-Kengo Kuma,0.6402023,digipa-high-impact
-Alfred Manessier,0.640153,fineart
-Antonio Galli Bibiena,0.6399247,digipa-high-impact
-Eduard von Grützner,0.6397164,fineart
-Bunny Yeager,0.5455078,digipa-high-impact
-Adolphe Willette,0.6396935,fineart
-Wangechi Mutu,0.6394607,n
-Peter Milligan,0.6391612,digipa-high-impact
-Dalí,0.45400402,digipa-low-impact
-Élisabeth Vigée Le Brun,0.6388982,fineart
-Beth Conklin,0.6388204,digipa-high-impact
-Charles Alphonse du Fresnoy,0.63881266,fineart
-Thomas Benjamin Kennington,0.56668127,fineart
-Jim Woodring,0.5625168,fineart
-Francisco Oller,0.63846034,fineart
-Csaba Markus,0.6384506,fineart
-Botero,0.63843524,scribbles
-Bill Henson,0.5394536,digipa-high-impact
-Anna Bocek,0.6382304,scribbles
-Hugo van der Goes,0.63822484,fineart
-Robert William Hume,0.5433574,fineart
-Chip Zdarsky,0.6381826,cartoon
-Daniel Seghers,0.53494316,fineart
-Richard Doyle,0.6377541,fineart
-Hendrick Terbrugghen,0.63773805,fineart
-Joe Madureira,0.6377177,special
-Floris van Schooten,0.6376191,fineart
-Jeff Simpson,0.3959046,fineart
-Albert Joseph Moore,0.6374316,fineart
-Arthur Merric Boyd,0.6373228,fineart
-Amadeo de Souza Cardoso,0.5927926,fineart
-Os Gemeos,0.6368859,digipa-high-impact
-Giovanni Boldini,0.6368698,fineart
-Albert Goodwin,0.6368695,fineart
-Hans Eduard von Berlepsch-Valendas,0.61562145,fineart
-Edmond Xavier Kapp,0.5758474,fineart
-François Quesnel,0.6365935,fineart
-Nathan Coley,0.6365817,digipa-high-impact
-Jasmine Becket-Griffith,0.6365083,digipa-high-impact
-Raphaelle Peale,0.6364422,fineart
-Candido Portinari,0.63634276,fineart
-Edward Dugmore,0.63179636,fineart
-Anders Zorn,0.6361722,fineart
-Ed Emshwiller,0.63615763,fineart
-Francis Coates Jones,0.6361159,fineart
-Ernst Haas,0.6361123,digipa-high-impact
-Dirck van Baburen,0.6213001,fineart
-René Lalique,0.63594735,fineart
-Sydney Prior Hall,0.6359345,fineart
-Brad Kunkle,0.5659712,fineart
-Corneille,0.6356381,fineart
-Henry Lamb,0.63560975,fineart
-Dirck Hals,0.63559663,fineart
-Alex Grey,0.62908936,nudity
-Michael Heizer,0.63555753,fineart
-Yiannis Moralis,0.61731136,fineart
-Emily Murray Paterson,0.4392335,fineart
-Georg Friedrich Kersting,0.6256248,fineart
-Frances Hodgkins,0.6352128,fineart
-Charles Cundall,0.6349486,fineart
-Henry Wallis,0.63478243,fineart
-Goro Fujita,0.6346491,cartoon
-Jean-Léon Gérôme,0.5954844,fineart
-August von Pettenkofen,0.60910493,fineart
-Abbott Handerson Thayer,0.63428533,fineart
-Martin John Heade,0.5926603,fineart
-Ellen Jewett,0.63420236,digipa-high-impact
-Hidari Jingorō,0.63388014,fareast
-Taiyō Matsumoto,0.63372946,special
-Emanuel Leutze,0.6007246,fineart
-Adam Martinakis,0.48973057,digipa-med-impact
-Will Eisner,0.63349223,cartoon
-Alexander Stirling Calder,0.6331682,fineart
-Saturno Butto,0.6331184,nudity
-Cecilia Beaux,0.6330725,fineart
-Amandine Van Ray,0.6174208,digipa-high-impact
-Bob Eggleton,0.63277495,digipa-high-impact
-Sherree Valentine Daines,0.63274443,fineart
-Frederick Lord Leighton,0.6299176,fineart
-Daniel Ridgway Knight,0.63251615,fineart
-Gaetano Previati,0.61743724,fineart
-John Berkey,0.63226986,fineart
-Richard Misrach,0.63201725,digipa-high-impact
-Aaron Jasinski,0.57948315,fineart
-"Edward Otho Cresap Ord, II",0.6317712,fineart
-Evelyn De Morgan,0.6317376,fineart
-Noelle Stevenson,0.63159716,digipa-high-impact
-Edward Robert Hughes,0.6315573,fineart
-Allan Ramsay,0.63150716,fineart
-Balthus,0.6314323,scribbles
-Hendrick Cornelisz Vroom,0.63143134,digipa-high-impact
-Ilya Repin,0.6313043,fineart
-George Lambourn,0.6312267,fineart
-Arthur Hughes,0.6310194,fineart
-Antonio J. Manzanedo,0.53841716,fineart
-John Singleton Copley,0.6264835,fineart
-Dennis Miller Bunker,0.63078755,fineart
-Ernie Barnes,0.6307126,cartoon
-Alison Kinnaird,0.6306353,digipa-high-impact
-Alex Toth,0.6305541,digipa-high-impact
-Henry Raeburn,0.6155551,fineart
-Alice Bailly,0.6305177,fineart
-Brian Kesinger,0.63037646,scribbles
-Antoine Blanchard,0.63036835,fineart
-Ron Walotsky,0.63035095,fineart
-Kent Monkman,0.63027304,fineart
-Naomi Okubo,0.5782754,fareast
-Hercules Seghers,0.62957174,fineart
-August Querfurt,0.6295643,fineart
-Samuel Melton Fisher,0.6283333,fineart
-David Burdeny,0.62950236,digipa-high-impact
-George Bain,0.58519644,fineart
-Peter Holme III,0.62938106,fineart
-Grayson Perry,0.62928164,digipa-high-impact
-Chris Claremont,0.6292076,digipa-high-impact
-Dod Procter,0.6291759,fineart
-Huang Tingjian,0.6290358,fareast
-Dorothea Warren O'Hara,0.6290113,fineart
-Ivan Albright,0.6289551,fineart
-Hubert von Herkomer,0.6288955,fineart
-Barbara Nessim,0.60589516,digipa-high-impact
-Henry Scott Tuke,0.6286309,fineart
-Ditlev Blunck,0.6282925,fineart
-Sven Nordqvist,0.62828535,fineart
-Lee Madgwick,0.6281731,fineart
-Hubert van Eyck,0.6281529,fineart
-Edmond Bille,0.62339354,fineart
-Ejnar Nielsen,0.6280824,fineart
-Arturo Souto,0.6280583,fineart
-Jean Giraud,0.6279888,fineart
-Storm Thorgerson,0.6277394,digipa-high-impact
-Ed Benedict,0.62764007,digipa-high-impact
-Christoffer Wilhelm Eckersberg,0.6014842,fineart
-Clarence Holbrook Carter,0.5514105,fineart
-Dorothy Lockwood,0.6273235,fineart
-John Singer Sargent,0.6272487,fineart
-Brigid Derham,0.6270125,digipa-high-impact
-Henricus Hondius II,0.6268505,fineart
-Gertrude Harvey,0.5903887,fineart
-Grant Wood,0.6266253,fineart
-Fyodor Vasilyev,0.5234919,digipa-med-impact
-Cagnaccio di San Pietro,0.6261671,fineart
-Doris Boulton-Maude,0.62593174,fineart
-Adolf Hirémy-Hirschl,0.5946784,fineart
-Harold von Schmidt,0.6256755,fineart
-Martine Johanna,0.6256161,digipa-high-impact
-Gerald Kelly,0.5579602,digipa-high-impact
-Ub Iwerks,0.625396,cartoon
-Dirck van der Lisse,0.6253871,fineart
-Edouard Riou,0.6250113,fineart
-Ilya Yefimovich Repin,0.62491584,fineart
-Martin Johnson Heade,0.59421235,fineart
-Afarin Sajedi,0.62475824,scribbles
-Alfred Thompson Bricher,0.6247515,fineart
-Edwin G. Lucas,0.5553578,fineart
-Georges Emile Lebacq,0.56175387,fineart
-Francis Davis Millet,0.5988504,fineart
-Bill Sienkiewicz,0.6125557,digipa-high-impact
-Giocondo Albertolli,0.62441677,fineart
-Victor Nizovtsev,0.6242258,fineart
-Squeak Carnwath,0.62416434,digipa-high-impact
-Bill Viola,0.62409425,digipa-high-impact
-Annie Abernethie Pirie Quibell,0.6240767,fineart
-Jason Edmiston,0.62405366,fineart
-Al Capp,0.6239494,fineart
-Kobayashi Kiyochika,0.6239368,anime
-Albert Anker,0.62389827,fineart
-Iain Faulkner,0.62376785,fineart
-Todd Schorr,0.6237408,fineart
-Charles Ginner,0.62370133,fineart
-Emile Auguste Carolus-Duran,0.62353987,fineart
-John Philip Falter,0.623418,cartoon
-Chizuko Yoshida,0.6233001,fareast
-Anna Dittmann,0.62327325,cartoon
-Henry Snell Gamley,0.62319934,fineart
-Edmund Charles Tarbell,0.6230626,fineart
-Rob Gonsalves,0.62298363,fineart
-Gladys Dawson,0.6228511,fineart
-Tomma Abts,0.61153626,fineart
-Kate Beaton,0.53993124,digipa-high-impact
-Gustave Buchet,0.62243867,fineart
-Gareth Pugh,0.6223551,digipa-high-impact
-Caspar van Wittel,0.57871693,fineart
-Anton Otto Fischer,0.6222941,fineart
-Albert Guillaume,0.56529653,fineart
-Felix Octavius Carr Darley,0.62223387,fineart
-Bernard van Orley,0.62221646,fineart
-Edward John Poynter,0.60147405,fineart
-Walter Percy Day,0.62207425,fineart
-Franciszek Starowieyski,0.5709621,fineart
-Auguste Baud-Bovy,0.6219854,fineart
-Chris LaBrooy,0.45497298,digipa-low-impact
-Abraham de Vries,0.5859101,fineart
-Antoni Gaudi,0.62162614,fineart
-Joe Jusko,0.62156093,digipa-high-impact
-Lynda Barry,0.62154603,digipa-high-impact
-Michal Karcz,0.62154436,digipa-high-impact
-Raymond Briggs,0.62150294,fineart
-Herbert James Gunn,0.6210927,fineart
-Dwight William Tryon,0.620984,fineart
-Paul Henry,0.5752968,fineart
-Helio Oiticica,0.6203739,digipa-high-impact
-Sebastian Errazuriz,0.62036186,digipa-high-impact
-Lucian Freud,0.6203146,nudity
-Frank Auerbach,0.6201102,weird
-Andre-Charles Boulle,0.6200789,fineart
-Franz Fedier,0.5669752,fineart
-Austin Briggs,0.57675314,fineart
-Hugo Sánchez Bonilla,0.61978436,digipa-high-impact
-Caroline Chariot-Dayez,0.6195682,digipa-high-impact
-Bill Ward,0.61953044,digipa-high-impact
-Charles Bird King,0.6194487,fineart
-Adrian Ghenie,0.6193521,digipa-high-impact
-Agnes Cecile,0.6192814,digipa-high-impact
-Augustus John,0.6191995,fineart
-Jeffrey T. Larson,0.61913544,fineart
-Alexis Simon Belle,0.3190395,digipa-low-impact
-Jean-Baptiste Monge,0.5758537,fineart
-Adolf Bierbrauer,0.56129396,fineart
-Ayako Rokkaku,0.61891204,fareast
-Lisa Keene,0.54570895,digipa-high-impact
-Edmond Aman-Jean,0.57168096,fineart
-Marc Davis,0.61837333,cartoon
-Cerith Wyn Evans,0.61829346,digipa-high-impact
-George Wyllie,0.61829203,fineart
-George Luks,0.6182724,fineart
-William-Adolphe Bouguereau,0.618265,c
-Grigoriy Myasoyedov,0.61801606,fineart
-Hashimoto Gahō,0.61795104,fineart
-Charles Ragland Bunnell,0.61772746,fineart
-Ambrose McCarthy Patterson,0.61764514,fineart
-Bill Brauer,0.5824066,fineart
-Mikko Lagerstedt,0.591015,digipa-high-impact
-Koson Ohara,0.53635323,fineart
-Evaristo Baschenis,0.5857368,fineart
-Martin Ansin,0.5294119,fineart
-Cory Loftis,0.6168619,cartoon
-Joseph Stella,0.6166778,fineart
-André Pijet,0.5768274,fineart
-Jeff Wall,0.6162895,digipa-high-impact
-Eleanor Layfield Davis,0.6158844,fineart
-Saul Tepper,0.61579347,fineart
-Alex Hirsch,0.6157384,cartoon
-Alexandre Falguière,0.55011404,fineart
-Malcolm Liepke,0.6155646,fineart
-Georg Friedrich Schmidt,0.60364646,fineart
-Hendrik Kerstens,0.55099905,digipa-high-impact
-Félix Bódog Widder,0.6153954,fineart
-Marie Guillemine Benoist,0.61532974,fineart
-Kelly Mckernan,0.60047054,digipa-high-impact
-Ignacio Zuloaga,0.6151608,fineart
-Hubert van Ravesteyn,0.61489964,fineart
-Angus McKie,0.61487424,digipa-high-impact
-Colin Campbell Cooper,0.6147882,fineart
-Pieter Aertsen,0.61454165,fineart
-Jan Brett,0.6144608,fineart
-Kazuo Koike,0.61438507,fineart
-Edith Grace Wheatley,0.61428297,fineart
-Ogawa Kazumasa,0.61427975,fareast
-Giovanni Battista Cipriani,0.6022825,fineart
-André Bauchant,0.57124996,fineart
-George Abe,0.6140447,digipa-high-impact
-Georges Lemmen,0.6139967,scribbles
-Frank Leonard Brooks,0.6139327,fineart
-Gai Qi,0.613744,anime
-Frank Gehry,0.6136776,digipa-high-impact
-Anton Domenico Gabbiani,0.55471313,fineart
-Cassandra Austen,0.6135781,fineart
-Paul Gustav Fischer,0.613273,fineart
-Emiliano Di Cavalcanti,0.6131207,fineart
-Meryl McMaster,0.6129995,digipa-high-impact
-Domenico di Pace Beccafumi,0.6129922,fineart
-Ludwig Mies van der Rohe,0.6126692,fineart
-Étienne-Louis Boullée,0.6126158,fineart
-Dali,0.5928694,nudity
-Shinji Aramaki,0.61246127,anime
-Giovanni Fattori,0.59544694,fineart
-Bapu,0.6122084,c
-Raphael Lacoste,0.5539114,digipa-high-impact
-Scarlett Hooft Graafland,0.6119631,digipa-high-impact
-Rene Laloux,0.61190474,fineart
-Julius Horsthuis,0.59037095,fineart
-Gerald van Honthorst,0.6115939,fineart
-Dino Valls,0.611533,fineart
-Tony DiTerlizzi,0.6114657,cartoon
-Michael Cheval,0.61138546,anime
-Charles Schulz,0.6113759,digipa-high-impact
-Alvar Aalto,0.61122143,digipa-high-impact
-Gu Kaizhi,0.6110798,fareast
-Eugene von Guerard,0.6109776,fineart
-John Cassaday,0.610949,fineart
-Elizabeth Forbes,0.61092335,fineart
-Edmund Greacen,0.6109115,fineart
-Eugène Burnand,0.6107876,fineart
-Boris Grigoriev,0.6107853,scribbles
-Norman Rockwell,0.6107638,fineart
-Barthélemy Menn,0.61064315,fineart
-George Biddle,0.61058354,fineart
-Edgar Ainsworth,0.5525424,digipa-high-impact
-Alfred Leyman,0.5887217,fineart
-Tex Avery,0.6104007,cartoon
-Beatrice Ethel Lithiby,0.61030364,fineart
-Grace Pailthorpe,0.61026484,digipa-high-impact
-Brian Oldham,0.396231,digipa-low-impact
-Android Jones,0.61023116,fareast
-François Girardon,0.5830649,fineart
-Ib Eisner,0.61016303,digipa-high-impact
-Armand Point,0.610156,fineart
-Henri Alphonse Barnoin,0.59465057,fineart
-Jean Marc Nattier,0.60987425,fineart
-Francisco de Holanda,0.6091294,fineart
-Marco Mazzoni,0.60970783,fineart
-Esaias Boursse,0.6093308,fineart
-Alexander Deyneka,0.55000365,fineart
-John Totleben,0.60883725,fineart
-Al Feldstein,0.6087723,fineart
-Adam Hughes,0.60854626,anime
-Ernest Zobole,0.6085073,fineart
-Alex Gross,0.60837066,digipa-high-impact
-George Jamesone,0.6079673,fineart
-Frank Lloyd Wright,0.60793245,scribbles
-Brooke DiDonato,0.47680336,digipa-med-impact
-Hans Gude,0.60780364,fineart
-Ethel Schwabacher,0.60748273,fineart
-Gladys Kathleen Bell,0.60747695,fineart
-Adolf Fényes,0.54192233,fineart
-Carel Willink,0.58120143,fineart
-George Henry,0.6070727,digipa-high-impact
-Ronald Balfour,0.60697085,fineart
-Elsie Dalton Hewland,0.6067718,digipa-high-impact
-Alex Maleev,0.6067118,fineart
-Anish Kapoor,0.6067015,digipa-high-impact
-Aleksandr Ivanovich Laktionov,0.606544,fineart
-Kim Keever,0.6037775,digipa-high-impact
-Aleksi Briclot,0.46056762,fineart
-Raymond Leech,0.6062721,fineart
-Richard Eurich,0.6062664,fineart
-Phil Jimenez,0.60625625,cartoon
-Gao Cen,0.60618126,nudity
-Mike Deodato,0.6061201,cartoon
-Charles Haslewood Shannon,0.6060581,fineart
-Alexandre Jacovleff,0.3991747,digipa-low-impact
-André Beauneveu,0.584062,fineart
-Hiroshi Honda,0.60507596,digipa-high-impact
-Charles Joshua Chaplin,0.60498774,fineart
-Domenico Zampieri,0.6049726,fineart
-Gusukuma Seihō,0.60479784,fareast
-Nikolina Petolas,0.46318632,digipa-low-impact
-Casey Weldon,0.6047672,cartoon
-Elmyr de Hory,0.6046374,fineart
-Nan Goldin,0.6046119,digipa-high-impact
-Charles McAuley,0.6045995,fineart
-Archibald Skirving,0.6044234,fineart
-Elizabeth York Brunton,0.6043737,fineart
-Dugald Sutherland MacColl,0.6042907,fineart
-Titian,0.60426414,fineart
-Ignacy Witkiewicz,0.6042259,fineart
-Allie Brosh,0.6042061,digipa-high-impact
-H.P. Lovecraft,0.6039597,digipa-high-impact
-Andrée Ruellan,0.60395086,fineart
-Ralph McQuarrie,0.60380936,fineart
-Mead Schaeffer,0.6036558,fineart
-Henri-Julien Dumont,0.571257,fineart
-Kieron Gillen,0.6035093,fineart
-Maginel Wright Enright Barney,0.6034306,nudity
-Vincent Di Fate,0.6034131,fineart
-Briton Rivière,0.6032918,fineart
-Hajime Sorayama,0.60325956,nudity
-Béla Czóbel,0.6031023,fineart
-Edmund Blampied,0.603072,fineart
-E. Simms Campbell,0.6030443,fineart
-Hisui Sugiura,0.603034,fareast
-Alan Davis,0.6029676,fineart
-Glen Keane,0.60287905,cartoon
-Frank Holl,0.6027312,fineart
-Abbott Fuller Graves,0.6025608,fineart
-Albert Servaes,0.60250103,black-white
-Hovsep Pushman,0.5937487,fineart
-Brian M. Viveros,0.60233414,fineart
-Charles Fremont Conner,0.6023278,fineart
-Francesco Furini,0.6022654,digipa-high-impact
-Camille-Pierre Pambu Bodo,0.60191673,fineart
-Yasushi Nirasawa,0.6016714,nudity
-Charles Uzzell-Edwards,0.6014683,fineart
-Abram Efimovich Arkhipov,0.60128385,fineart
-Hedda Sterne,0.6011857,digipa-high-impact
-Ben Aronson,0.6011548,fineart
-Frank Frazetta,0.551121,nudity
-Elizabeth Durack,0.6010842,fineart
-Ian Miller,0.42153555,fareast
-Charlie Bowater,0.4410439,special
-Michael Carson,0.60039437,fineart
-Walter Langley,0.6002273,fineart
-Cornelis Anthonisz,0.6001956,fineart
-Dorothy Elizabeth Bradford,0.6001929,fineart
-J.C. Leyendecker,0.5791972,fineart
-Willem van Haecht,0.59990716,fineart
-Anna and Elena Balbusso,0.59955937,digipa-low-impact
-Harrison Fisher,0.59952044,fineart
-Bill Medcalf,0.59950054,fineart
-Edward Arthur Walton,0.59945667,fineart
-Alois Arnegger,0.5991994,fineart
-Ray Caesar,0.59902894,digipa-high-impact
-Karen Wallis,0.5990094,fineart
-Emmanuel Shiu,0.51082766,digipa-med-impact
-Thomas Struth,0.5988324,digipa-high-impact
-Barbara Longhi,0.5985706,fineart
-Richard Deacon,0.59851056,fineart
-Constantin Hansen,0.5984213,fineart
-Harold Shapinsky,0.5984175,fineart
-George Dionysus Ehret,0.5983857,fineart
-Doug Wildey,0.5983639,digipa-high-impact
-Fernand Toussaint,0.5982694,fineart
-Horatio Nelson Poole,0.5982614,fineart
-Caesar van Everdingen,0.5981566,fineart
-Eva Gonzalès,0.5981396,fineart
-Franz Vohwinkel,0.5448179,fineart
-Margaret Mee,0.5979592,fineart
-Francis Focer Brown,0.59779185,fineart
-Henry Moore,0.59767926,nudity
-Scott Listfield,0.58795893,fineart
-Nikolai Ge,0.5973643,fineart
-Jacek Yerka,0.58198756,fineart
-Margaret Brundage,0.5969077,fineart
-JC Leyendecker,0.5620243,fineart
-Ben Templesmith,0.5498991,digipa-high-impact
-Armin Hansen,0.59669334,anime
-Jean-Louis Prevost,0.5966897,fineart
-Daphne Allen,0.59666026,fineart
-Franz Karl Basler-Kopp,0.59663445,fineart
-"Henry Ives Cobb, Jr.",0.596385,fineart
-Michael Sowa,0.546285,fineart
-Anna Füssli,0.59600973,fineart
-György Rózsahegyi,0.59580946,fineart
-Luis Royo,0.59566617,fineart
-Émile Gallé,0.5955559,fineart
-Antonio Mora,0.5334297,digipa-high-impact
-Edward P. Beard Jr.,0.59543866,fineart
-Jessica Rossier,0.54958373,special
-André Thomkins,0.5343785,digipa-high-impact
-David Macbeth Sutherland,0.5949968,fineart
-Charles Liu,0.5949787,digipa-high-impact
-Edi Rama,0.5949226,digipa-high-impact
-Jacques Le Moyne,0.5948843,fineart
-Egbert van der Poel,0.59488285,fineart
-Georg Jensen,0.594782,digipa-high-impact
-Anne Sudworth,0.5947539,fineart
-Jan Pietersz Saenredam,0.59472525,fineart
-Henryk Stażewski,0.5945748,fineart
-André François,0.58402044,fineart
-Alexander Runciman,0.5944449,digipa-high-impact
-Thomas Kinkade,0.594391,fineart
-Robert Williams,0.5567989,digipa-high-impact
-George Gardner Symons,0.57431924,fineart
-D. Alexander Gregory,0.5334464,fineart
-Gerald Brom,0.52473724,fineart
-Robert Hagan,0.59406,fineart
-Ernest Crichlow,0.5940588,fineart
-Viviane Sassen,0.5939927,digipa-high-impact
-Enrique Simonet,0.5937546,fineart
-Esther Blaikie MacKinnon,0.593747,digipa-high-impact
-Jeff Kinney,0.59372896,scribbles
-Igor Morski,0.5936732,digipa-high-impact
-John Currin,0.5936216,fineart
-Bob Ringwood,0.5935273,digipa-high-impact
-Jordan Grimmer,0.44948143,digipa-low-impact
-François Barraud,0.5933471,fineart
-Helen Binyon,0.59331006,digipa-high-impact
-Brenda Chamberlain,0.5932333,fineart
-Candido Bido,0.59310603,fineart
-Abraham Storck,0.5929502,fineart
-Raphael,0.59278333,fineart
-Larry Sultan,0.59273386,digipa-high-impact
-Agostino Tassi,0.59265685,fineart
-Alexander V. Kuprin,0.5925917,fineart
-Frans Koppelaar,0.5658725,fineart
-Richard Corben,0.59251785,fineart
-David Gilmour Blythe,0.5924247,digipa-high-impact
-František Kaván,0.5924211,fineart
-Rob Liefeld,0.5921167,fineart
-Ernő Rubik,0.5920297,fineart
-Byeon Sang-byeok,0.59200096,fareast
-Johfra Bosschart,0.5919376,fineart
-Emil Lindenfeld,0.5761086,fineart
-Howard Mehring,0.5917471,fineart
-Gwenda Morgan,0.5915571,digipa-high-impact
-Henry Asencio,0.5915404,fineart
-"George Barret, Sr.",0.5914306,fineart
-Andrew Ferez,0.5911011,fineart
-Ed Brubaker,0.5910869,digipa-high-impact
-George Reid,0.59095883,digipa-high-impact
-Derek Gores,0.51769906,digipa-med-impact
-Charles Rollier,0.5539186,fineart
-Terry Oakes,0.590443,fineart
-Thomas Blackshear,0.5078616,fineart
-Albert Benois,0.5902705,nudity
-Krenz Cushart,0.59026587,special
-Jeff Koons,0.5902637,digipa-high-impact
-Akihiko Yoshida,0.5901294,special
-Anja Percival,0.45039332,digipa-low-impact
-Eduard von Steinle,0.59008586,fineart
-Alex Russell Flint,0.5900352,digipa-high-impact
-Edward Okuń,0.5897297,fineart
-Emma Lampert Cooper,0.5894849,fineart
-Stuart Haygarth,0.58132994,digipa-high-impact
-George French Angas,0.5434376,fineart
-Edmund F. Ward,0.5892848,fineart
-Eleanor Vere Boyle,0.58925456,digipa-high-impact
-Evelyn Cheston,0.58924586,fineart
-Edwin Dickinson,0.58921975,digipa-high-impact
-Christophe Vacher,0.47325426,fineart
-Anne Dewailly,0.58905107,fineart
-Gertrude Greene,0.5862596,digipa-high-impact
-Boris Groh,0.5888809,digipa-high-impact
-Douglas Smith,0.588804,digipa-high-impact
-Ian Hamilton Finlay,0.5887713,fineart
-Derek Jarman,0.5887292,digipa-high-impact
-Archibald Thorburn,0.5882001,fineart
-Gillis d'Hondecoeter,0.58813053,fineart
-I Ketut Soki,0.58801544,digipa-high-impact
-Alex Schomburg,0.46614102,digipa-low-impact
-Bastien L. Deharme,0.583349,special
-František Jakub Prokyš,0.58782333,fineart
-Jesper Ejsing,0.58782053,fineart
-Odd Nerdrum,0.53551745,digipa-high-impact
-Tom Lovell,0.5877577,fineart
-Ayami Kojima,0.5877416,fineart
-Peter Sculthorpe,0.5875696,fineart
-Bernard D’Andrea,0.5874042,fineart
-Denis Eden,0.58739066,digipa-high-impact
-Alfons Walde,0.58728385,fineart
-Jovana Rikalo,0.47006977,digipa-low-impact
-Franklin Booth,0.5870834,fineart
-Mat Collishaw,0.5870676,digipa-high-impact
-Joseph Lorusso,0.586858,fineart
-Helen Stevenson,0.454647,digipa-low-impact
-Delaunay,0.58657396,fineart
-H.R. Millar,0.58655745,fineart
-E. Charlton Fortune,0.586376,fineart
-Alson Skinner Clark,0.58631575,fineart
-Stan And Jan Berenstain,0.5862361,digipa-high-impact
-Howard Lyon,0.5862271,fineart
-John Blanche,0.586182,fineart
-Bernardo Cavallino,0.5858575,fineart
-Tomasz Alen Kopera,0.5216588,fineart
-Peter Gric,0.58583695,fineart
-Guo Pei,0.5857794,fareast
-James Turrell,0.5853901,digipa-high-impact
-Alexandr Averin,0.58533764,fineart
-Bertalan Székely,0.5548113,digipa-high-impact
-Brothers Hildebrandt,0.5850233,fineart
-Ed Roth,0.5849769,digipa-high-impact
-Enki Bilal,0.58492255,fineart
-Alan Lee,0.5848701,fineart
-Charles H. Woodbury,0.5848688,fineart
-André Charles Biéler,0.5847876,fineart
-Annie Rose Laing,0.5597829,fineart
-Matt Fraction,0.58463776,cartoon
-Charles Alston,0.58453286,fineart
-Frank Xavier Leyendecker,0.545465,fineart
-Alfred Richard Gurrey,0.584306,fineart
-Dan Mumford,0.5843051,cartoon
-Francisco Martín,0.5842005,fineart
-Alvaro Siza,0.58406967,digipa-high-impact
-Frank J. Girardin,0.5839858,fineart
-Henry Carr,0.58397424,digipa-high-impact
-Charles Furneaux,0.58394694,fineart
-Daniel F. Gerhartz,0.58389103,fineart
-Gilberto Soren Zaragoza,0.5448442,fineart
-Bart Sears,0.5838427,cartoon
-Allison Bechdel,0.58383805,digipa-high-impact
-Frank O'Meara,0.5837992,fineart
-Charles Codman,0.5836579,fineart
-Francisco Zúñiga,0.58359766,fineart
-Vladimir Kush,0.49075457,fineart
-Arnold Mesches,0.5834257,fineart
-Frank McKelvey,0.5831641,fineart
-Allen Butler Talcott,0.5830911,fineart
-Eric Zener,0.58300316,fineart
-Noah Bradley,0.44176096,digipa-low-impact
-Robert Childress,0.58289623,fineart
-Frances C. Fairman,0.5827239,fineart
-Kathryn Morris Trotter,0.465856,digipa-low-impact
-Everett Raymond Kinstler,0.5824819,fineart
-Edward Mitchell Bannister,0.5804899,fineart
-"George Barret, Jr.",0.5823128,fineart
-Greg Hildebrandt,0.4271311,fineart
-Anka Zhuravleva,0.5822078,digipa-high-impact
-Rolf Armstrong,0.58217514,fineart
-Eric Wallis,0.58191466,fineart
-Clemens Ascher,0.5480207,digipa-high-impact
-Hugo Kārlis Grotuss,0.5818766,fineart
-Albert Paris Gütersloh,0.5817827,fineart
-Hilda May Gordon,0.5817449,fineart
-Hendrik Martenszoon Sorgh,0.5817126,fineart
-Pipilotti Rist,0.5816868,digipa-high-impact
-Hiroyuki Tajima,0.5816242,fareast
-Igor Zenin,0.58159757,digipa-high-impact
-Genevieve Springston Lynch,0.4979099,digipa-med-impact
-Dan Witz,0.44476372,fineart
-David Roberts,0.5255326,fineart
-Frieke Janssens,0.5706969,digipa-high-impact
-Arnold Schoenberg,0.56520367,fineart
-Inoue Naohisa,0.5809933,fareast
-Elfriede Lohse-Wächtler,0.58097905,fineart
-Alex Ross,0.42460668,digipa-low-impact
-Robert Irwin,0.58078,c
-Charles Angrand,0.58077514,fineart
-Anne Nasmyth,0.54221964,fineart
-Henri Bellechose,0.5773891,fineart
-De Hirsh Margules,0.58059025,fineart
-Hiromitsu Takahashi,0.5805599,fareast
-Ilya Kuvshinov,0.5805521,special
-Cassius Marcellus Coolidge,0.5805516,c
-Dorothy Burroughes,0.5804835,fineart
-Emanuel de Witte,0.58027405,fineart
-George Herbert Baker,0.5799624,digipa-high-impact
-Cheng Zhengkui,0.57990086,fareast
-Bernard Fleetwood-Walker,0.57987773,digipa-high-impact
-Philippe Parreno,0.57985014,digipa-high-impact
-Thornton Oakley,0.57969713,fineart
-Greg Rutkowski,0.5203395,special
-Ike no Taiga,0.5795857,anime
-Eduardo Lefebvre Scovell,0.5795808,fineart
-Adolfo Müller-Ury,0.57944727,fineart
-Patrick Woodroffe,0.5228063,fineart
-Wim Crouwel,0.57933235,digipa-high-impact
-Colijn de Coter,0.5792779,fineart
-François Boquet,0.57924724,fineart
-Gerbrand van den Eeckhout,0.57897866,fineart
-Eugenio Granell,0.5392264,fineart
-Kuang Hong,0.5782304,digipa-high-impact
-Justin Gerard,0.46685404,fineart
-Tokujin Yoshioka,0.5779153,digipa-high-impact
-Alan Bean,0.57788515,fineart
-Ernest Biéler,0.5778079,fineart
-Martin Deschambault,0.44401115,digipa-low-impact
-Anna Boch,0.577735,fineart
-Jack Davis,0.5775291,fineart
-Félix Labisse,0.5775142,fineart
-Greg Simkins,0.5679761,fineart
-David Lynch,0.57751054,digipa-low-impact
-Eizō Katō,0.5774127,digipa-high-impact
-Grethe Jürgens,0.5773412,digipa-high-impact
-Heinrich Bichler,0.5770147,fineart
-Barbara Nasmyth,0.5446056,fineart
-Domenico Induno,0.5583946,fineart
-Gustave Baumann,0.5607866,fineart
-Mike Mayhew,0.5765857,cartoon
-Delmer J. 
Yoakum,0.576538,fineart -Aykut Aydogdu,0.43111503,digipa-low-impact -George Barker,0.5763551,fineart -Ernő Grünbaum,0.57634187,fineart -Eliseu Visconti,0.5763241,fineart -Esao Andrews,0.5761547,fineart -JennyBird Alcantara,0.49165845,digipa-med-impact -Joan Tuset,0.5761051,fineart -Angela Barrett,0.55976534,digipa-high-impact -Syd Mead,0.5758396,fineart -Ignacio Bazan-Lazcano,0.5757512,fineart -Franciszek Kostrzewski,0.57570386,fineart -Eero Järnefelt,0.57540673,fineart -Loretta Lux,0.56217635,digipa-high-impact -Gaudi,0.57519895,fineart -Charles Gleyre,0.57490873,fineart -Antoine Verney-Carron,0.56386137,fineart -Albert Edelfelt,0.57466495,fineart -Fabian Perez,0.57444525,fineart -Kevin Sloan,0.5737548,fineart -Stanislav Poltavsky,0.57434607,fineart -Abraham Hondius,0.574326,fineart -Tadao Ando,0.57429105,fareast -Fyodor Slavyansky,0.49796474,digipa-med-impact -David Brewster,0.57385933,digipa-high-impact -Cliff Chiang,0.57375133,digipa-high-impact -Drew Struzan,0.5317983,digipa-high-impact -Henry O. Tanner,0.5736586,fineart -Alberto Sughi,0.5736495,fineart -Albert J. Welti,0.5736257,fineart -Charles Mahoney,0.5735923,digipa-high-impact -Exekias,0.5734506,fineart -Felipe Seade,0.57342744,digipa-high-impact -Henriette Wyeth,0.57330644,digipa-high-impact -Harold Sandys Williamson,0.5443646,fineart -Eddie Campbell,0.57329535,digipa-high-impact -Gao Fenghan,0.5732926,fareast -Cynthia Sheppard,0.51099646,fineart -Henriette Grindat,0.573179,fineart -Yasutomo Oka,0.5731342,fareast -Celia Frances Bedford,0.57313216,fineart -Les Edwards,0.42068473,fineart -Edwin Deakin,0.5031717,fineart -Eero Saarinen,0.5725142,digipa-high-impact -Franciszek Smuglewicz,0.5722554,fineart -Doris Blair,0.57221186,fineart -Seb Mckinnon,0.51721895,digipa-med-impact -Gregorio Lazzarini,0.57204294,fineart -Gerard Sekoto,0.5719927,fineart -Francis Ernest Jackson,0.5506009,fineart -Simon Birch,0.57171595,digipa-high-impact -Bayard Wu,0.57171166,fineart -François Clouet,0.57162094,fineart -Christopher Wren,0.5715372,fineart -Evgeny Lushpin,0.5714827,special -Art Green,0.5714495,digipa-high-impact -Amy Judd,0.57142305,digipa-high-impact -Art Brenner,0.42619684,digipa-low-impact -Travis Louie,0.43916368,digipa-low-impact -James Jean,0.5457318,digipa-high-impact -Ewald Rübsamen,0.57083976,fineart -Donato Giancola,0.57052535,fineart -Carl Arnold Gonzenbach,0.5703996,fineart -Bastien Lecouffe-Deharme,0.5201288,fineart -Howard Chandler Christy,0.5702813,nudity -Dean Cornwell,0.56977296,fineart -Don Maitz,0.4743015,fineart -James Montgomery Flagg,0.56974065,fineart -Andreas Levers,0.42125136,digipa-low-impact -Edgar Schofield Baum,0.56965977,fineart -Alan Parry,0.5694952,digipa-high-impact -An Zhengwen,0.56942475,fareast -Alayna Lemmer,0.48293802,fineart -Edward Marshall Boehm,0.5530143,fineart -Henri Biva,0.54013556,nudity -Fiona Rae,0.4646715,digipa-low-impact -Elizabeth Jane Lloyd,0.5688463,digipa-high-impact -Franklin Carmichael,0.5687844,digipa-high-impact -Dionisius,0.56875896,fineart -Edwin Georgi,0.56868523,fineart -Jenny Saville,0.5686633,fineart -Ernest Hébert,0.56859314,fineart -Stephan Martiniere,0.56856346,digipa-high-impact -Huang Binhong,0.56841767,fineart -August Lemmer,0.5683548,fineart -Camille Bouvagne,0.5678048,fineart -Olga Skomorokhova,0.39401102,digipa-low-impact -Sacha Goldberger,0.5675477,digipa-high-impact -Hilda Annetta Walker,0.5675261,digipa-high-impact -Harvey Pratt,0.51314723,digipa-med-impact -Jean Bourdichon,0.5670543,fineart -Noriyoshi Ohrai,0.56690073,fineart -Kadir Nelson,0.5669006,n -Ilya 
Ostroukhov,0.5668801,fineart -Eugène Brands,0.56681967,fineart -Achille Leonardi,0.56674325,fineart -Franz Cižek,0.56670356,fineart -George Paul Chalmers,0.5665988,digipa-high-impact -Serge Marshennikov,0.5665971,digipa-high-impact -Mike Worrall,0.56641084,fineart -Dirck van Delen,0.5661764,fineart -Peter Andrew Jones,0.5661655,fineart -Rafael Albuquerque,0.56541103,fineart -Daniel Buren,0.5654043,fineart -Giuseppe Grisoni,0.5432699,fineart -George Fiddes Watt,0.55861616,fineart -Stan Lee,0.5651268,digipa-high-impact -Dorning Rasbotham,0.56511617,fineart -Albert Lynch,0.56497896,fineart -Lorenz Hideyoshi,0.56494075,fineart -Fenghua Zhong,0.56492203,fareast -Caroline Lucy Scott,0.49190843,digipa-med-impact -Victoria Crowe,0.5647996,digipa-high-impact -Hasegawa Settan,0.5647092,fareast -Dennis H. Farber,0.56453323,digipa-high-impact -Dick Bickenbach,0.5644289,fineart -Art Frahm,0.56439924,fineart -Edith Edmonds,0.5643151,fineart -Alfred Heber Hutty,0.56419206,fineart -Henry Tonks,0.56410825,fineart -Peter Howson,0.5640759,fineart -Albert Dorne,0.56395364,fineart -Arthur Adams,0.5639404,fineart -Bernt Tunold,0.56383425,digipa-high-impact -Gianluca Foli,0.5637317,digipa-high-impact -Vittorio Matteo Corcos,0.5636767,fineart -Béla Iványi-Grünwald,0.56355745,nudity -Feng Zhu,0.5634973,fineart -Sam Kieth,0.47251505,digipa-low-impact -Charles Crodel,0.5633834,fineart -Elsie Henderson,0.56310076,digipa-high-impact -George Earl Ortman,0.56295705,fineart -Tari Márk Dávid,0.562937,fineart -Betty Merken,0.56281745,digipa-high-impact -Cecile Walton,0.46672013,digipa-low-impact -Bracha L. Ettinger,0.56237936,fineart -Ken Fairclough,0.56230986,digipa-high-impact -Phil Koch,0.56224954,digipa-high-impact -George Pirie,0.56213045,digipa-high-impact -Chad Knight,0.56194013,digipa-high-impact -Béla Kondor,0.5427164,digipa-high-impact -Barclay Shaw,0.53689134,digipa-high-impact -Tim Hildebrandt,0.47194147,fineart -Hermann Rüdisühli,0.56104004,digipa-high-impact -Ian McQue,0.5342066,digipa-high-impact -Yanjun Cheng,0.5607171,fineart -Heinrich Hofmann,0.56060636,fineart -Henry Raleigh,0.5605958,fineart -Ernest Buckmaster,0.5605704,fineart -Charles Ricketts,0.56055415,fineart -Juergen Teller,0.56051147,digipa-high-impact -Auguste Mambour,0.5604873,fineart -Sean Yoro,0.5601486,digipa-high-impact -Sheilah Beckett,0.55995446,digipa-high-impact -Eugene Tertychnyi,0.5598978,fineart -Dr. Seuss,0.5597466,c -Adolf Wölfli,0.5372333,digipa-high-impact -Enrique Tábara,0.559323,fineart -Dionisio Baixeras Verdaguer,0.5590695,fineart -Aleksander Gierymski,0.5590013,fineart -Augustus Dunbier,0.55872476,fineart -Adolf Born,0.55848217,fineart -Chris Turnham,0.5584234,digipa-high-impact -James C Christensen,0.55837405,fineart -Daphne Fedarb,0.5582459,digipa-high-impact -Andre Kohn,0.5581832,special -Ron Mueck,0.5581811,nudity -Glenn Fabry,0.55786383,fineart -Elizabeth Polunin,0.5578102,digipa-high-impact -Charles S. 
Kaelin,0.5577954,fineart -Arthur Radebaugh,0.5577016,fineart -Ai Yazawa,0.55768114,fareast -Charles Roka,0.55762553,fineart -Ai Weiwei,0.5576034,digipa-high-impact -Dorothy Bradford,0.55760014,digipa-high-impact -Alfred Leslie,0.557555,fineart -Heinrich Herzig,0.5574423,fineart -Eliot Hodgkin,0.55740607,digipa-high-impact -Albert Kotin,0.55737317,fineart -Carlo Carlone,0.55729353,fineart -Chen Rong,0.5571221,fineart -Ikuo Hirayama,0.5570225,digipa-high-impact -Edward Corbett,0.55701995,nudity -Eugeniusz Żak,0.556925,nudity -Ettore Tito,0.556875,fineart -Helene Knoop,0.5567731,fineart -Amanda Sage,0.37731662,fareast -Annick Bouvattier,0.54647046,fineart -Harvey Dunn,0.55663586,fineart -Hans Sandreuter,0.5562575,digipa-high-impact -Ruan Jia,0.5398549,special -Anton Räderscheidt,0.55618906,fineart -Tyler Shields,0.4081434,digipa-low-impact -Darek Zabrocki,0.49975997,digipa-med-impact -Frank Montague Moore,0.5556432,fineart -Greg Staples,0.5555332,fineart -Endre Bálint,0.5553731,fineart -Augustus Vincent Tack,0.5136602,fineart -Marc Simonetti,0.48602036,fineart -Carlo Randanini,0.55493265,digipa-high-impact -Diego Dayer,0.5549119,fineart -Kelly Freas,0.55476534,fineart -Thomas Saliot,0.5139967,digipa-med-impact -Gijsbert d'Hondecoeter,0.55455256,fineart -Walter Kim,0.554521,digipa-high-impact -Francesco Cozza,0.5155097,digipa-med-impact -Bill Watterson,0.5542879,digipa-high-impact -Mark Keathley,0.4824056,fineart -Béni Ferenczy,0.55405354,digipa-high-impact -Amadou Opa Bathily,0.5536976,n -Giuseppe Antonio Petrini,0.55340284,fineart -Enzo Cucchi,0.55331933,digipa-high-impact -Adolf Schrödter,0.55316544,fineart -George Benjamin Luks,0.548566,fineart -Glenys Cour,0.55304,digipa-high-impact -Andrew Robertson,0.5529603,digipa-high-impact -Claude Rogers,0.55272067,digipa-high-impact -Alexandre Antigna,0.5526737,fineart -Aimé Barraud,0.55265915,digipa-high-impact -György Vastagh,0.55258965,fineart -Bruce Nauman,0.55257386,digipa-high-impact -Benjamin Block,0.55251944,digipa-high-impact -Gonzalo Endara Crow,0.552346,digipa-high-impact -Dirck de Bray,0.55221736,fineart -Gerald Kelley,0.5521059,digipa-high-impact -Dave Gibbons,0.5520954,digipa-high-impact -Béla Nagy Abodi,0.5520624,digipa-high-impact -Faith 47,0.5517006,digipa-high-impact -Anna Razumovskaya,0.5229187,digipa-med-impact -Archibald Robertson,0.55129635,digipa-high-impact -Louise Dahl-Wolfe,0.55120385,digipa-high-impact -Simon Bisley,0.55119276,digipa-high-impact -Eric Fischl,0.55107886,fineart -Hu Zaobin,0.5510481,fareast -Béla Pállik,0.5507963,digipa-high-impact -Eugene J. 
Martin,0.55078864,fineart -Friedrich Gauermann,0.55063415,fineart -Fritz Baumann,0.5341434,fineart -Michal Lisowski,0.5505639,fineart -Paolo Roversi,0.5503342,digipa-high-impact -Andrew Atroshenko,0.55009747,fineart -Gyula Derkovits,0.5500315,fineart -Hugh Adam Crawford,0.55000615,digipa-high-impact -Béla Apáti Abkarovics,0.5499799,digipa-high-impact -Paul Chadeisson,0.389151,digipa-low-impact -Aurél Bernáth,0.54968774,fineart -Albert Henry Krehbiel,0.54952574,fineart -Piet Hein Eek,0.54918796,digipa-high-impact -Yoshitaka Amano,0.5491855,fareast -Antonio Rotta,0.54909515,fineart -Józef Mehoffer,0.50760424,fineart -Donald Sherwood,0.5490415,digipa-high-impact -Catrin G Grosse,0.5489286,digipa-high-impact -Arthur Webster Emerson,0.5478842,fineart -Incarcerated Jerkfaces,0.5488423,digipa-high-impact -Emanuel Büchel,0.5487217,fineart -Andrew Loomis,0.54854584,fineart -Charles Hopkinson,0.54853606,fineart -Gabor Szikszai,0.5485203,digipa-high-impact -Archibald Standish Hartrick,0.54850936,digipa-high-impact -Aleksander Orłowski,0.546705,nudity -Hans Hinterreiter,0.5483628,fineart -Fred Williams,0.54544824,fineart -Fred A. Precht,0.5481606,fineart -Camille Souter,0.5213742,fineart -Emil Fuchs,0.54807395,fineart -Francesco Bonsignori,0.5478936,fineart -H. R. (Hans Ruedi) Giger,0.547799,fineart -Harriet Zeitlin,0.5477388,digipa-high-impact -Christian Jane Fergusson,0.5396168,fineart -Edward Kemble,0.5476892,fineart -Bernard Aubertin,0.5475396,fineart -Augustyn Mirys,0.5474162,fineart -Alejandro Burdisio,0.47482288,special -Erin Hanson,0.4343264,digipa-low-impact -Amalia Lindegren,0.5471987,digipa-high-impact -Alberto Seveso,0.47735062,fineart -Bartholomeus Strobel,0.54703736,fineart -Jim Davis,0.54703003,digipa-high-impact -Antony Gormley,0.54696125,digipa-high-impact -Charles Marion Russell,0.54696095,fineart -George B. Sutherland,0.5467901,fineart -Almada Negreiros,0.54670584,fineart -Edward Armitage,0.54358315,fineart -Bruno Walpoth,0.546167,digipa-high-impact -Richard Hamilton,0.5461275,nudity -Charles Harold Davis,0.5460415,digipa-high-impact -Fernand Verhaegen,0.54601514,fineart -Bernard Meninsky,0.5302034,digipa-high-impact -Fede Galizia,0.5456873,digipa-high-impact -Alfred Kelsner,0.5455753,nudity -Fritz Puempin,0.5452847,fineart -Alfred Charles Parker,0.54521024,fineart -Ahmed Yacoubi,0.544767,digipa-high-impact -Arthur B. Carles,0.54447794,fineart -Alice Prin,0.54435575,digipa-high-impact -Carl Gustaf Pilo,0.5443212,digipa-high-impact -Ross Tran,0.5259248,special -Hideyuki Kikuchi,0.544193,fareast -Art Fitzpatrick,0.49847245,fineart -Cherryl Fountain,0.5440454,fineart -Skottie Young,0.5440119,cartoon -NC Wyeth,0.54382974,digipa-high-impact -Rudolf Freund,0.5437342,fineart -Mort Kunstler,0.5433619,digipa-high-impact -Ben Goossens,0.53002644,digipa-high-impact -Andreas Rocha,0.49621177,special -Gérard Ernest Schneider,0.5429964,fineart -Francesco Filippini,0.5429598,digipa-high-impact -Alejandro Jodorowsky,0.5429065,digipa-high-impact -Friedrich Traffelet,0.5428817,fineart -Honor C. Appleton,0.5428735,digipa-high-impact -Jason A. 
Engle,0.542821,fineart -Henry Otto Wix,0.54271996,fineart -Gregory Manchess,0.54270375,fineart -Ann Stookey,0.54269934,digipa-high-impact -Henryk Rodakowski,0.542589,fineart -Albert Welti,0.5425134,digipa-high-impact -Gerard Houckgeest,0.5424413,digipa-high-impact -Dorothy Hood,0.54226196,digipa-high-impact -Frank Schoonover,0.51056194,fineart -Erlund Hudson,0.5422107,digipa-high-impact -Alexander Litovchenko,0.54210097,fineart -Sakai Hōitsu,0.5420294,digipa-high-impact -Benito Quinquela Martín,0.54194224,fineart -David Watson Stevenson,0.54191554,fineart -Ann Thetis Blacker,0.5416629,digipa-high-impact -Frank DuMond,0.51004076,digipa-med-impact -David Dougal Williams,0.5410126,digipa-high-impact -Robert Mcginnis,0.54098356,fineart -Ernest Briggs,0.5408636,fineart -Ferenc Joachim,0.5408625,fineart -Carlos Saenz de Tejada,0.47332364,digipa-low-impact -David Burton-Richardson,0.49659324,digipa-med-impact -Ernest Heber Thompson,0.54039246,digipa-high-impact -Albert Bertelsen,0.54038215,nudity -Giorgio Giulio Clovio,0.5403708,fineart -Eugene Leroy,0.54019785,digipa-high-impact -Anna Findlay,0.54018176,digipa-high-impact -Roy Gjertson,0.54012,digipa-high-impact -Charmion von Wiegand,0.5400893,fineart -Arnold Bronckhorst,0.526247,fineart -Boris Vallejo,0.487253,fineart -Adélaïde Victoire Hall,0.539939,fineart -Earl Norem,0.5398575,fineart -Sanford Kossin,0.53977877,digipa-high-impact -Aert de Gelder,0.519166,digipa-med-impact -Carl Eugen Keel,0.539739,digipa-high-impact -Francis Bourgeois,0.5397272,digipa-high-impact -Bojan Jevtic,0.41141546,fineart -Edward Avedisian,0.5393925,fineart -Gao Xiang,0.5392419,fareast -Charles Hinman,0.53911865,digipa-high-impact -Frits Van den Berghe,0.53896487,fineart -Carlo Martini,0.5384833,digipa-high-impact -Elina Karimova,0.5384318,digipa-high-impact -Anto Carte,0.4708289,digipa-low-impact -Andrey Yefimovich Martynov,0.537721,fineart -Frances Jetter,0.5376904,fineart -Yuri Ivanovich Pimenov,0.5342793,fineart -Gaston Anglade,0.537608,digipa-high-impact -Albert Swinden,0.5375844,fineart -Bob Byerley,0.5375774,fineart -A.B. Frost,0.5375025,fineart -Jaya Suberg,0.5372893,digipa-high-impact -Josh Keyes,0.53654516,digipa-high-impact -Juliana Huxtable,0.5364195,n -Everett Warner,0.53641814,digipa-high-impact -Hugh Kretschmer,0.45171157,digipa-low-impact -Arnold Blanch,0.535774,fineart -Ryan McGinley,0.53572595,digipa-high-impact -Alfons Karpiński,0.53564656,fineart -George Aleef,0.5355317,digipa-high-impact -Hal Foster,0.5351446,fineart -Stuart Immonen,0.53501946,digipa-high-impact -Craig Thompson,0.5346844,digipa-high-impact -Bartolomeo Vivarini,0.53465015,fineart -Hermann Feierabend,0.5346168,digipa-high-impact -Antonio Donghi,0.4610982,digipa-low-impact -Adonna Khare,0.4858036,digipa-med-impact -James Stokoe,0.5015107,digipa-med-impact -Agustín Fernández,0.53403986,fineart -Germán Londoño,0.5338712,fineart -Emmanuelle Moureaux,0.5335641,digipa-high-impact -Conrad Marca-Relli,0.5148334,digipa-med-impact -Gyula Batthyány,0.5332407,fineart -Francesco Raibolini,0.53314835,fineart -Apelles,0.5166026,fineart -Marat Latypov,0.45811993,fineart -Andrei Markin,0.5328752,fineart -Einar Hakonarson,0.5328311,digipa-high-impact -Beatrice Huntington,0.5328165,digipa-high-impact -Coppo di Marcovaldo,0.5327443,fineart -Gregorio Prestopino,0.53250784,fineart -A.D.M. 
Cooper,0.53244877,digipa-high-impact -Horatio McCulloch,0.53244334,digipa-high-impact -Wes Anderson,0.5318741,digipa-high-impact -Moebius,0.53178746,digipa-high-impact -Gerard Soest,0.53160626,fineart -Charles Ellison,0.53152347,digipa-high-impact -Wojciech Ostrycharz,0.5314213,fineart -Doug Chiang,0.5313724,fineart -Anne Savage,0.5310638,digipa-high-impact -Cor Melchers,0.53099334,fineart -Gordon Browne,0.5308195,digipa-high-impact -Augustus Earle,0.49196815,fineart -Carlos Francisco Chang Marín,0.5304734,fineart -Larry Elmore,0.53032553,fineart -Adolf Hölzel,0.5303149,fineart -David Ligare,0.5301894,fineart -Jan Luyken,0.52985555,fineart -Earle Bergey,0.5298525,fineart -David Ramsay Hay,0.52974963,digipa-high-impact -Alfred East,0.5296565,digipa-high-impact -A. R. Middleton Todd,0.50988734,fineart -Giorgio De Vincenzi,0.5291678,fineart -Hugh William Williams,0.5291014,digipa-high-impact -Erwin Bowien,0.52895796,digipa-high-impact -Victor Adame Minguez,0.5288686,fineart -Yoji Shinkawa,0.5287015,anime -Clara Weaver Parrish,0.5284487,digipa-high-impact -Albert Eckhout,0.5284096,fineart -Dorothy Coke,0.5282345,digipa-high-impact -Jerzy Duda-Gracz,0.5279943,digipa-high-impact -Byron Galvez,0.39178842,fareast -Alson S. Clark,0.5278568,digipa-high-impact -Adolf Ulric Wertmüller,0.5278296,digipa-high-impact -Bruce Coville,0.5277226,digipa-high-impact -Gong Kai,0.5276811,digipa-high-impact -Andréi Arinouchkine,0.52763486,digipa-high-impact -Florence Engelbach,0.5273161,digipa-high-impact -Brian Froud,0.5270276,fineart -Charles Thomson,0.5270127,digipa-high-impact -Bessie Wheeler,0.5269164,digipa-high-impact -Anton Lehmden,0.5268611,fineart -Emilia Wilk,0.5264961,fineart -Carl Eytel,0.52646196,digipa-high-impact -Alfred Janes,0.5264481,digipa-high-impact -Julie Bell,0.49962538,fineart -Eugenio de Arriba,0.52613926,digipa-high-impact -Samuel and Joseph Newsom,0.52595663,digipa-high-impact -Hans Falk,0.52588874,digipa-high-impact -Guillermo del Toro,0.52565175,digipa-high-impact -Félix Arauz,0.52555984,digipa-high-impact -Gyula Basch,0.52524436,digipa-high-impact -Haroon Mirza,0.5252279,digipa-high-impact -Du Jin,0.5249934,digipa-med-impact -Harry Shoulberg,0.5249456,digipa-med-impact -Arie Smit,0.5249027,fineart -Ahmed Karahisari,0.4259451,digipa-low-impact -Brian and Wendy Froud,0.5246335,fineart -E. William Gollings,0.52461207,digipa-med-impact -Bo Bartlett,0.51341593,digipa-med-impact -Hans Burgkmair,0.52416867,digipa-med-impact -David Macaulay,0.5241233,digipa-med-impact -Benedetto Caliari,0.52370214,digipa-med-impact -Eliott Lilly,0.5235398,digipa-med-impact -Vincent Tanguay,0.48578292,digipa-med-impact -Ada Hill Walker,0.52207166,fineart -Christopher Wood,0.49360397,digipa-med-impact -Kris Kuksi,0.43938053,digipa-low-impact -Chen Yifei,0.5217867,fineart -Margaux Valonia,0.5217782,digipa-med-impact -Antoni Pitxot,0.40582713,digipa-low-impact -Jhonen Vasquez,0.5216471,digipa-med-impact -Emilio Grau Sala,0.52156484,fineart -Henry B. 
Christian,0.52153796,fineart -Jacques Nathan-Garamond,0.52144086,digipa-med-impact -Eddie Mendoza,0.4949638,digipa-med-impact -Grzegorz Rutkowski,0.48906532,special -Beeple,0.40085253,digipa-low-impact -Giorgio Cavallon,0.5209209,digipa-med-impact -Godfrey Blow,0.52062386,digipa-med-impact -Gabriel Dawe,0.5204431,fineart -Emile Lahner,0.5202367,digipa-med-impact -Steve Dillon,0.5201676,digipa-med-impact -Lee Quinones,0.4626683,digipa-low-impact -Hale Woodruff,0.52000225,digipa-med-impact -Tom Hammick,0.5032626,digipa-med-impact -Hamilton Sloan,0.5197798,digipa-med-impact -Caesar Andrade Faini,0.51971483,digipa-med-impact -Sam Spratt,0.48991,digipa-med-impact -Chris Cold,0.4753577,fineart -Alejandro Obregón,0.5190562,digipa-med-impact -Dan Flavin,0.51901346,digipa-med-impact -Arthur Sarnoff,0.5189428,fineart -Elenore Abbott,0.5187141,digipa-med-impact -Andrea Kowch,0.51822996,digipa-med-impact -Demetrios Farmakopoulos,0.5181248,digipa-med-impact -Alexis Grimou,0.41958088,digipa-low-impact -Lesley Vance,0.5177536,digipa-med-impact -Gyula Aggházy,0.517747,fineart -Georgina Hunt,0.46105456,digipa-low-impact -Christian W. Staudinger,0.4684662,digipa-low-impact -Abraham Begeyn,0.5172538,digipa-med-impact -Charles Mozley,0.5171356,digipa-med-impact -Elias Ravanetti,0.38719344,digipa-low-impact -Herman van Swanevelt,0.5168748,digipa-med-impact -David Paton,0.4842217,digipa-med-impact -Hans Werner Schmidt,0.51671976,digipa-med-impact -Bob Ross,0.51628315,fineart -Sou Fujimoto,0.5162528,fareast -Balcomb Greene,0.5162045,digipa-med-impact -Glen Angus,0.51609933,digipa-med-impact -Buckminster Fuller,0.51607454,digipa-med-impact -Andrei Ryabushkin,0.5158933,fineart -Almeida Júnior,0.515856,digipa-med-impact -Tim White,0.4182697,digipa-low-impact -Hans Beat Wieland,0.51553553,digipa-med-impact -Jakub Różalski,0.5154904,digipa-med-impact -John Whitcomb,0.51523805,digipa-med-impact -Dorothy King,0.5150925,digipa-med-impact -Richard S. Johnson,0.51500344,fineart -Aniello Falcone,0.51475304,digipa-med-impact -Henning Jakob Henrik Lund,0.5147134,c -Robert M Cunningham,0.5144858,digipa-med-impact -Nick Knight,0.51447505,digipa-med-impact -David Chipperfield,0.51424,digipa-med-impact -Bartolomeo Cesi,0.5136737,digipa-med-impact -Bettina Heinen-Ayech,0.51334465,digipa-med-impact -Annabel Kidston,0.51327646,digipa-med-impact -Charles Schridde,0.51308405,digipa-med-impact -Samuel Earp,0.51305825,digipa-med-impact -Eugene Montgomery,0.5128343,digipa-med-impact -Alfred Parsons,0.5127445,digipa-med-impact -Anton Möller,0.5127209,digipa-med-impact -Craig Davison,0.499598,special -Cricorps Grégoire,0.51267076,fineart -Celia Fiennes,0.51266706,digipa-med-impact -Raymond Swanland,0.41350424,fineart -Howard Knotts,0.5122062,digipa-med-impact -Helmut Federle,0.51201206,digipa-med-impact -Tyler Edlin,0.44028252,digipa-high-impact -Elwood H. 
Smith,0.5119027,digipa-med-impact -Ralph Horsley,0.51142794,fineart -Alexander Ivanov,0.4539051,digipa-low-impact -Cedric Peyravernay,0.4200587,digipa-low-impact -Annabel Eyres,0.51136214,digipa-med-impact -Zack Snyder,0.51129746,digipa-med-impact -Gentile Bellini,0.511102,digipa-med-impact -Giovanni Pelliccioli,0.4868688,digipa-med-impact -Fikret Muallâ Saygı,0.510694,digipa-med-impact -Bauhaus,0.43454266,digipa-low-impact -Charles Williams,0.510406,digipa-med-impact -Georg Arnold-Graboné,0.5103381,digipa-med-impact -Fedot Sychkov,0.47935224,digipa-med-impact -Alberto Magnelli,0.5103212,digipa-med-impact -Aloysius O'Kelly,0.5102891,digipa-med-impact -Alexander McQueen,0.5101986,digipa-med-impact -Cam Sykes,0.510071,digipa-med-impact -George Lucas,0.510038,digipa-med-impact -Eglon van der Neer,0.5099339,digipa-med-impact -Christian August Lorentzen,0.50989646,digipa-med-impact -Eleanor Best,0.50966686,digipa-med-impact -Terry Redlin,0.474244,fineart -Ken Kelly,0.4304738,fineart -David Eugene Henry,0.48173362,fineart -Shin Jeongho,0.5092497,fareast -Flora Borsi,0.5091922,digipa-med-impact -Berndnaut Smilde,0.50864,digipa-med-impact -Art of Brom,0.45828784,fineart -Ernő Tibor,0.50851977,digipa-med-impact -Ancell Stronach,0.5084514,digipa-med-impact -Helen Thomas Dranga,0.45412368,digipa-low-impact -Anita Malfatti,0.5080986,digipa-med-impact -Arnold Brügger,0.5080749,digipa-med-impact -Edward Ben Avram,0.50778764,digipa-med-impact -Antonio Ciseri,0.5073538,fineart -Alyssa Monks,0.50734174,digipa-med-impact -Chen Zhen,0.5071876,digipa-med-impact -Francis Helps,0.50707847,digipa-med-impact -Georg Karl Pfahler,0.50700235,digipa-med-impact -Henry Woods,0.506811,digipa-med-impact -Barbara Greg,0.50674164,digipa-med-impact -Guan Daosheng,0.506712,fareast -Guy Billout,0.5064906,digipa-med-impact -Basuki Abdullah,0.50613165,digipa-med-impact -Thomas Visscher,0.5059943,digipa-med-impact -Edward Simmons,0.50598735,digipa-med-impact -Arabella Rankin,0.50572735,digipa-med-impact -Lady Pink,0.5056634,digipa-high-impact -Christopher Williams,0.5052288,digipa-med-impact -Fuyuko Matsui,0.5051116,fareast -Edward Baird,0.5049874,digipa-med-impact -Georges Stein,0.5049069,digipa-med-impact -Alex Alemany,0.43974748,digipa-low-impact -Emanuel Schongut,0.5047326,digipa-med-impact -Hans Bol,0.5045265,digipa-med-impact -Kurzgesagt,0.5043725,digipa-med-impact -Harald Giersing,0.50410193,digipa-med-impact -Antonín Slavíček,0.5040368,fineart -Carl Rahl,0.5040115,digipa-med-impact -Etienne Delessert,0.5037818,fineart -Americo Makk,0.5034161,digipa-med-impact -Fernand Pelez,0.5027561,digipa-med-impact -Alexey Merinov,0.4469615,digipa-low-impact -Caspar Netscher,0.5019529,digipa-med-impact -Walt Disney,0.50178146,digipa-med-impact -Qian Xuan,0.50150526,fareast -Geoffrey Dyer,0.50120556,digipa-med-impact -Andre Norton,0.5007602,digipa-med-impact -Daphne McClure,0.5007391,digipa-med-impact -Dieric Bouts,0.5005882,fineart -Aguri Uchida,0.5005107,fareast -Hugo Scheiber,0.50004864,digipa-med-impact -Kenne Gregoire,0.46421963,digipa-low-impact -Wolfgang Tillmans,0.4999767,fineart -Carl-Henning Pedersen,0.4998986,digipa-med-impact -Alison Debenham,0.4998683,digipa-med-impact -Eppo Doeve,0.49975222,digipa-med-impact -Christen Købke,0.49961317,digipa-med-impact -Aron Demetz,0.49895018,digipa-med-impact -Alesso Baldovinetti,0.49849576,digipa-med-impact -Jimmy Lawlor,0.4475271,fineart -Carl Walter Liner,0.49826378,fineart -Gwenny Griffiths,0.45598924,digipa-low-impact -David Cooke Gibson,0.4976222,digipa-med-impact -Howard 
Butterworth,0.4974621,digipa-med-impact -Bob Thompson,0.49743804,fineart -Enguerrand Quarton,0.49711192,fineart -Abdel Hadi Al Gazzar,0.49631482,digipa-med-impact -Gu Zhengyi,0.49629828,digipa-med-impact -Aleksander Kotsis,0.4953621,digipa-med-impact -Alexander Sharpe Ross,0.49519226,digipa-med-impact -Carlos Enríquez Gómez,0.49494863,digipa-med-impact -Abed Abdi,0.4948855,digipa-med-impact -Elaine Duillo,0.49474388,digipa-med-impact -Anne Said,0.49473995,digipa-med-impact -Istvan Banyai,0.4947369,digipa-med-impact -Bouchta El Hayani,0.49455142,digipa-med-impact -Chinwe Chukwuogo-Roy,0.49445248,n -George Claessen,0.49412063,digipa-med-impact -Axel Törneman,0.49401706,digipa-med-impact -Avigdor Arikha,0.49384058,digipa-med-impact -Gloria Stoll Karn,0.4937976,digipa-med-impact -Alfredo Volpi,0.49367586,digipa-med-impact -Raffaello Sanizo,0.49365884,digipa-med-impact -Jeff Easley,0.49344411,digipa-med-impact -Aileen Eagleton,0.49318358,digipa-med-impact -Gaetano Sabatini,0.49307147,digipa-med-impact -Bertalan Pór,0.4930132,digipa-med-impact -Alfred Jensen,0.49291304,digipa-med-impact -Huang Guangjian,0.49286693,fareast -Emil Ferris,0.49282396,digipa-med-impact -Derek Chittock,0.492694,digipa-med-impact -Alonso Vázquez,0.49205148,digipa-med-impact -Kelly Sue Deconnick,0.4919476,digipa-med-impact -Clive Madgwick,0.4749857,fineart -Edward George Handel Lucas,0.49166748,digipa-med-impact -Dorothea Braby,0.49161923,digipa-med-impact -Sangyeob Park,0.49150884,fareast -Heinz Edelman,0.49140438,digipa-med-impact -Mark Seliger,0.4912073,digipa-med-impact -Camilo Egas,0.4586727,digipa-low-impact -Craig Mullins,0.49085408,fineart -Dong Kingman,0.49063343,digipa-med-impact -Douglas Robertson Bisset,0.49031347,digipa-med-impact -Blek Le Rat,0.49008566,digipa-med-impact -Anton Ažbe,0.48984748,fineart -Olafur Eliasson,0.48971075,digipa-med-impact -Elinor Proby Adams,0.48967826,digipa-med-impact -Cándido López,0.48915705,digipa-med-impact -D. Howard Hitchcock,0.48902267,digipa-med-impact -Cheng Jiasui,0.48889247,fareast -Jean Nouvel,0.4888183,digipa-med-impact -Bill Gekas,0.48848945,digipa-med-impact -Hermione Hammond,0.48845994,digipa-med-impact -Fernando Gerassi,0.48841453,digipa-med-impact -Frank Barrington Craig,0.4883762,digipa-med-impact -A. B. Jackson,0.4883623,digipa-med-impact -Bernie D’Andrea,0.48813275,digipa-med-impact -Clarice Beckett,0.487809,digipa-med-impact -Dosso Dossi,0.48775777,digipa-med-impact -Donald Roller Wilson,0.48767656,digipa-med-impact -Ernest William Christmas,0.4876317,digipa-med-impact -Aleksandr Gerasimov,0.48736423,digipa-med-impact -Edward Clark,0.48703307,digipa-med-impact -Georg Schrimpf,0.48697302,digipa-med-impact -John Wilhelm,0.48696536,digipa-med-impact -Aries Moross,0.4863676,digipa-med-impact -Bill Lewis,0.48635158,digipa-med-impact -Huang Ji,0.48611963,fareast -F. Scott Hess,0.43634564,fineart -Gao Qipei,0.4860631,fareast -Albert Tucker,0.4854299,digipa-med-impact -Barbara Balmer,0.48528513,fineart -Anne Ryan,0.48511976,digipa-med-impact -Helen Edwards,0.48484707,digipa-med-impact -Alexander Bogen,0.48421195,digipa-med-impact -David Annand,0.48418126,digipa-med-impact -Du Qiong,0.48414314,fareast -Fred Cress,0.4837878,digipa-med-impact -David B. 
Mattingly,0.48370445,digipa-med-impact -Hristofor Žefarović,0.4837008,digipa-med-impact -Wim Wenders,0.44484183,digipa-low-impact -Alexander Fedosav,0.48360944,digipa-med-impact -Anne Rigney,0.48357943,digipa-med-impact -Bertalan Karlovszky,0.48338628,digipa-med-impact -George Frederick Harris,0.4833259,fineart -Toshiharu Mizutani,0.48315164,fareast -David McClellan,0.39739317,digipa-low-impact -Eugeen Van Mieghem,0.48270774,digipa-med-impact -Alexei Harlamoff,0.48255378,digipa-med-impact -Jeff Legg,0.48249072,digipa-med-impact -Elizabeth Murray,0.48227608,digipa-med-impact -Hugo Heyrman,0.48213717,digipa-med-impact -Adrian Paul Allinson,0.48211843,digipa-med-impact -Altoon Sultan,0.4820177,digipa-med-impact -Alice Mason,0.48188528,fareast -Harriet Powers,0.48181778,digipa-med-impact -Aaron Bohrod,0.48175076,digipa-med-impact -Chris Saunders,0.41429797,digipa-low-impact -Clara Miller Burd,0.47797233,digipa-med-impact -David G. Sorensen,0.38101727,digipa-low-impact -Iwan Baan,0.4806739,digipa-med-impact -Anatoly Metlan,0.48020265,digipa-med-impact -Alfons von Czibulka,0.4801954,digipa-med-impact -Amedee Ozenfant,0.47950014,digipa-med-impact -Valerie Hegarty,0.47947168,digipa-med-impact -Hugo Anton Fisher,0.4793551,digipa-med-impact -Antonio Roybal,0.4792729,digipa-med-impact -Cui Zizhong,0.47902682,fareast -F Scott Hess,0.42582104,fineart -Julien Delval,0.47888556,digipa-med-impact -Marcin Jakubowski,0.4788583,digipa-med-impact -Anne Stokes,0.4786997,digipa-med-impact -David Palumbo,0.47632077,fineart -Hallsteinn Sigurðsson,0.47858906,digipa-med-impact -Mike Campau,0.47850558,digipa-med-impact -Giuseppe Avanzi,0.47846943,digipa-med-impact -Harry Morley,0.47836518,digipa-med-impact -Constance-Anne Parker,0.47832203,digipa-med-impact -Albert Keller,0.47825447,digipa-med-impact -Daniel Chodowiecki,0.47825167,digipa-med-impact -Alasdair Grant Taylor,0.47802624,digipa-med-impact -Maria Pascual Alberich,0.4779718,fineart -Rebeca Saray,0.41697127,digipa-low-impact -Ernő Bánk,0.47753686,digipa-med-impact -Shaddy Safadi,0.47724134,digipa-med-impact -André Castro,0.4771826,digipa-med-impact -Amiet Cuno,0.41975892,digipa-low-impact -Adi Granov,0.40670198,fineart -Allen Williams,0.47675848,digipa-med-impact -Anna Haifisch,0.47672725,digipa-med-impact -Clovis Trouille,0.47669724,digipa-med-impact -Jane Graverol,0.47655866,digipa-med-impact -Conroy Maddox,0.47645602,digipa-med-impact -Božidar Jakac,0.4763106,digipa-med-impact -George Morrison,0.47533786,digipa-med-impact -Douglas Bourgeois,0.47527707,digipa-med-impact -Cao Zhibai,0.47476804,fareast -Bradley Walker Tomlin,0.47462896,digipa-low-impact -Dave Dorman,0.46852386,fineart -Stevan Dohanos,0.47452107,fineart -John Howe,0.44144905,fineart -Fanny McIan,0.47406268,digipa-low-impact -Bholekar Srihari,0.47387534,digipa-low-impact -Giovanni Lanfranco,0.4737344,digipa-low-impact -Fred Marcellino,0.47346023,digipa-low-impact -Clyde Caldwell,0.47305286,fineart -Haukur Halldórsson,0.47275954,digipa-low-impact -Huang Gongwang,0.47269204,fareast -Brothers Grimm,0.47249007,digipa-low-impact -Ollie Hoff,0.47240657,digipa-low-impact -RHADS,0.4722166,digipa-low-impact -Constance Gordon-Cumming,0.47219282,digipa-low-impact -Anne Mccaffrey,0.4719924,digipa-low-impact -Henry Heerup,0.47190166,digipa-low-impact -Adrian Smith,0.4716923,digipa-high-impact -Harold Elliott,0.4714101,digipa-low-impact -Eric Peterson,0.47106332,digipa-low-impact -David Garner,0.47106326,digipa-low-impact -Edward Hicks,0.4708863,digipa-low-impact -Alfred Krupa,0.47052455,digipa-low-impact 
-Breyten Breytenbach,0.4699338,digipa-low-impact -Douglas Shuler,0.4695691,digipa-low-impact -Elaine Hamilton,0.46941522,digipa-low-impact -Kapwani Kiwanga,0.46917036,digipa-low-impact -Dan Scott,0.46897763,digipa-low-impact -Allan Brooks,0.46882123,digipa-low-impact -Ian Fairweather,0.46878594,digipa-low-impact -Arlington Nelson Lindenmuth,0.4683814,digipa-low-impact -Russell Ayto,0.4681503,digipa-low-impact -Allan Linder,0.46812692,digipa-low-impact -Bohumil Kubista,0.4679809,digipa-low-impact -Christopher Jin Baron,0.4677839,digipa-low-impact -Eero Snellman,0.46777654,digipa-low-impact -Christabel Dennison,0.4677633,digipa-low-impact -Amelia Peláez,0.46764764,digipa-low-impact -James Gurney,0.46740666,digipa-low-impact -Carles Delclaux Is,0.46734855,digipa-low-impact -George Papazov,0.42420334,digipa-low-impact -Mark Brooks,0.4672415,fineart -Anne Dunn,0.46722376,digipa-low-impact -Klaus Wittmann,0.4670704,fineart -Arvid Nyholm,0.46697336,digipa-low-impact -Georg Scholz,0.46674117,digipa-low-impact -David Spriggs,0.46671993,digipa-low-impact -Ernest Morgan,0.4665036,digipa-low-impact -Ella Guru,0.46619284,digipa-low-impact -Helen Berman,0.46614346,digipa-low-impact -Gen Paul,0.4658785,digipa-low-impact -Auseklis Ozols,0.46569023,digipa-low-impact -Amelia Robertson Hill,0.4654411,fineart -Jim Lee,0.46544096,digipa-low-impact -Anson Maddocks,0.46539295,digipa-low-impact -Chen Hong,0.46516004,fareast -Haddon Sundblom,0.46490777,digipa-low-impact -Eva Švankmajerová,0.46454152,digipa-low-impact -Antonio Cavallucci,0.4645282,digipa-low-impact -Herve Groussin,0.40050638,digipa-low-impact -Gwen Barnard,0.46400994,digipa-low-impact -Grace English,0.4638674,digipa-low-impact -Carl Critchlow,0.4636,digipa-low-impact -Ayshia Taşkın,0.463412,digipa-low-impact -Alison Watt,0.43141022,digipa-low-impact -Andre de Krayewski,0.4628024,digipa-low-impact -Hamish MacDonald,0.462645,digipa-low-impact -Ni Chuanjing,0.46254826,fareast -Frank Mason,0.46254665,digipa-low-impact -Steve Henderson,0.43113405,fineart -Eileen Aldridge,0.46210572,digipa-low-impact -Brad Rigney,0.28446302,digipa-low-impact -Ching Yeh,0.46177,fareast -Bertram Brooker,0.46176457,digipa-low-impact -Henry Bright,0.46150023,digipa-low-impact -Claire Dalby,0.46117848,digipa-low-impact -Brian Despain,0.41538632,digipa-low-impact -Anna Maria Barbara Abesch,0.4611045,digipa-low-impact -Bernardo Daddi,0.46088326,digipa-low-impact -Abraham Mintchine,0.46088243,digipa-high-impact -Alexander Carse,0.46078917,digipa-low-impact -Doc Hammer,0.46075988,digipa-low-impact -Yuumei,0.46072406,digipa-low-impact -Teophilus Tetteh,0.46064255,n -Bess Hamiti,0.46062252,digipa-low-impact -Ceferí Olivé,0.46058378,digipa-low-impact -Enrique Grau,0.46046937,digipa-low-impact -Eleanor Hughes,0.46007007,digipa-low-impact -Elizabeth Charleston,0.46001568,digipa-low-impact -Félix Ziem,0.45987016,digipa-low-impact -Eugeniusz Zak,0.45985222,digipa-low-impact -Dain Yoon,0.45977795,fareast -Gong Xian,0.4595083,digipa-low-impact -Flavia Blois,0.45950204,digipa-low-impact -Frederik Vermehren,0.45949826,digipa-low-impact -Gang Se-hwang,0.45937777,digipa-low-impact -Bjørn Wiinblad,0.45934483,digipa-low-impact -Alex Horley-Orlandelli,0.42623433,digipa-low-impact -Dr. 
Atl,0.459287,digipa-low-impact -Hu Jieqing,0.45889485,fareast -Amédée Ozenfant,0.4585215,digipa-low-impact -Warren Ellis,0.4584044,digipa-low-impact -Helen Dahm,0.45804346,digipa-low-impact -Anne Geddes,0.45785287,digipa-low-impact -Bikash Bhattacharjee,0.45775396,digipa-low-impact -Phil Foglio,0.457582,digipa-low-impact -Evelyn Abelson,0.4574563,digipa-low-impact -Alan Moore,0.4573369,digipa-low-impact -Josh Kao,0.45725146,fareast -Bertil Nilsson,0.45724383,digipa-low-impact -Hristofor Zhefarovich,0.457089,fineart -Edward Bailey,0.45659882,digipa-low-impact -Christopher Moeller,0.45648077,digipa-low-impact -Dóra Keresztes,0.4558745,fineart -Cory Arcangel,0.4558071,digipa-low-impact -Aleksander Kobzdej,0.45552525,digipa-low-impact -Tim Burton,0.45541722,digipa-high-impact -Chen Jiru,0.4553378,fareast -George Passantino,0.4552104,digipa-low-impact -Fuller Potter,0.4552072,digipa-low-impact -Warwick Globe,0.45516664,digipa-low-impact -Heinz Anger,0.45466962,digipa-low-impact -Elias Goldberg,0.45416242,digipa-low-impact -tokyogenso,0.45406622,fareast -Zeen Chin,0.45404464,digipa-low-impact -Albert Koetsier,0.45385844,fineart -Giuseppe Camuncoli,0.45377725,digipa-low-impact -Elsie Vera Cole,0.45377362,digipa-low-impact -Andreas Franke,0.4300047,digipa-low-impact -Constantine Andreou,0.4533816,digipa-low-impact -Elisabeth Collins,0.45337808,digipa-low-impact -Ted Nasmith,0.45302224,fineart -Antônio Parreiras,0.45269623,digipa-low-impact -Gwilym Prichard,0.45256525,digipa-low-impact -Fang Congyi,0.45240825,fareast -Huang Ding,0.45233482,fareast -Hans von Bartels,0.45200723,digipa-low-impact -Peter Elson,0.4121406,fineart -Fan Kuan,0.4513034,digipa-low-impact -Dean Roger,0.45112592,digipa-low-impact -Bernat Sanjuan,0.45074993,fareast -Fletcher Martin,0.45055175,digipa-low-impact -Gentile Tondino,0.45043385,digipa-low-impact -Ei-Q,0.45038772,digipa-low-impact -Chen Lin,0.45035738,fareast -Ted Wallace,0.4500007,digipa-low-impact -"Cornelisz Hendriksz Vroom, the Younger",0.4499252,digipa-low-impact -Alpo Jaakola,0.44981295,digipa-low-impact -Clark Voorhees,0.4495309,digipa-low-impact -Cleve Gray,0.449188,digipa-low-impact -Wolf Kahn,0.4489858,digipa-low-impact -Choi Buk,0.44892842,fareast -Frank Tinsley,0.4480373,digipa-low-impact -George Bell,0.44779524,digipa-low-impact -Fiona Stephenson,0.44761062,fineart -Carlos Trillo Name,0.4470371,digipa-low-impact -Jamie McKelvie,0.44696707,digipa-low-impact -Dennis Flanders,0.44673377,digipa-low-impact -Dulah Marie Evans,0.44662604,digipa-low-impact -Hans Schwarz,0.4463275,digipa-low-impact -Steve McCurry,0.44620228,digipa-low-impact -Bedwyr Williams,0.44616276,digipa-low-impact -Anton Graff,0.38569996,digipa-low-impact -Leticia Gillett,0.44578317,digipa-low-impact -Rafał Olbiński,0.44561762,digipa-low-impact -Artgerm,0.44555497,fineart -Adrienn Henczné Deák,0.445518,digipa-low-impact -Gu Hongzhong,0.4454906,fareast -Matt Groening,0.44518438,digipa-low-impact -Sue Bryce,0.4447164,digipa-low-impact -Armin Baumgarten,0.444061,digipa-low-impact -Araceli Gilbert,0.44399196,digipa-low-impact -Carey Morris,0.44388965,digipa-low-impact -Ignat Bednarik,0.4438085,digipa-low-impact -Frank Buchser,0.44373792,digipa-low-impact -Ben Zoeller,0.44368798,digipa-low-impact -Adam Szentpétery,0.4434548,fineart -Gene Davis,0.44343877,digipa-low-impact -Fei Danxu,0.4433627,fareast -Andrei Kolkoutine,0.44328922,digipa-low-impact -Bruce Onobrakpeya,0.42588046,n -Christoph Amberger,0.38912287,digipa-low-impact -"Fred Mitchell,",0.4432277,digipa-low-impact -Klaus 
Burgle,0.44295216,digipa-low-impact -Carl Hoppe,0.44270635,digipa-low-impact -Caroline Gotch,0.44263047,digipa-low-impact -Hans Mertens,0.44260004,digipa-low-impact -Mandy Disher,0.44219893,fineart -Sarah Lucas,0.4420507,digipa-low-impact -Sydney Edmunds,0.44198513,digipa-low-impact -Amos Ferguson,0.4418735,digipa-low-impact -Alton Tobey,0.4416385,digipa-low-impact -Clifford Ross,0.44139367,digipa-low-impact -Henric Trenk,0.4412782,digipa-low-impact -Claire Hummel,0.44119984,digipa-low-impact -Norman Foster,0.4411899,digipa-low-impact -Carmen Saldana,0.44076762,digipa-low-impact -Michael Whelan,0.4372847,digipa-low-impact -Carlos Berlanga,0.440354,digipa-low-impact -Gilles Beloeil,0.43997732,digipa-low-impact -Ashley Wood,0.4398396,digipa-low-impact -David Allan,0.43969798,digipa-low-impact -Mark Lovett,0.43922082,digipa-low-impact -Jed Henry,0.43882954,digipa-low-impact -Adam Bruce Thomson,0.43847767,digipa-low-impact -Horst Antes,0.4384303,digipa-low-impact -Fritz Glarner,0.43787453,digipa-low-impact -Harold McCauley,0.43760818,digipa-low-impact -Estuardo Maldonado,0.437594,digipa-low-impact -Dai Jin,0.4375449,fareast -Fabien Charuau,0.43688047,digipa-low-impact -Chica Macnab,0.4365166,digipa-low-impact -Jim Burns,0.3975072,digipa-low-impact -Santiago Calatrava,0.43651623,digipa-low-impact -Robert Maguire,0.40926617,digipa-low-impact -Cliff Childs,0.43611953,digipa-low-impact -Charles Martin,0.43582463,fareast -Elbridge Ayer Burbank,0.43572164,digipa-low-impact -Anita Kunz,0.4356005,digipa-low-impact -Colin Geller,0.43559563,digipa-low-impact -Allen Tupper True,0.43556124,digipa-low-impact -Jef Wu,0.43555313,digipa-low-impact -Jon McCoy,0.4147122,digipa-low-impact -Cedric Seaut,0.43521535,digipa-low-impact -Emily Shanks,0.43519047,digipa-low-impact -Andrew Whem,0.43512022,digipa-low-impact -Ibrahim Kodra,0.43471518,digipa-low-impact -Harrington Mann,0.4345901,digipa-low-impact -Jerry Siegel,0.43458986,digipa-low-impact -Howard Kanovitz,0.4345178,digipa-low-impact -Cicely Hey,0.43449926,digipa-low-impact -Ben Thompson,0.43436068,digipa-low-impact -Joe Bowler,0.43413073,digipa-low-impact -Lori Earley,0.43389612,digipa-low-impact -Arent Arentsz,0.43373522,digipa-low-impact -David Bailly,0.43371305,digipa-low-impact -Hans Arnold,0.4335214,digipa-low-impact -Constance Copeman,0.4334836,digipa-low-impact -Brent Heighton,0.4333118,fineart -Eric Taylor,0.43312082,digipa-low-impact -Aleksander Gine,0.4326849,digipa-low-impact -Alexander Johnston,0.4326589,digipa-low-impact -David Park,0.43235332,digipa-low-impact -Balázs Diószegi,0.432244,digipa-low-impact -Ed Binkley,0.43222216,digipa-low-impact -Eric Dinyer,0.4321258,digipa-low-impact -Susan Luo,0.43198025,fareast -Cedric Seaut (Keos Masons),0.4317356,digipa-low-impact -Lorena Alvarez Gómez,0.431683,digipa-low-impact -Fred Ludekens,0.431662,digipa-low-impact -David Begbie,0.4316218,digipa-low-impact -Ai Xuan,0.43150818,fareast -Felix-Kelly,0.43132153,digipa-low-impact -Antonín Chittussi,0.431248,digipa-low-impact -Ammi Phillips,0.43095884,digipa-low-impact -Elke Vogelsang,0.43092483,digipa-low-impact -Fathi Hassan,0.43090487,digipa-low-impact -Angela Sung,0.391746,fareast -Clément Serveau,0.43050706,digipa-low-impact -Dong Yuan,0.4303865,fareast -Hew Lorimer,0.43035403,digipa-low-impact -David Finch,0.29487437,digipa-low-impact -Bill Durgin,0.4300932,digipa-low-impact -Alexander Robertson,0.4300743,digipa-low-impact diff --git a/extensions-builtin/roll-artist/scripts/roll-artist.py b/extensions-builtin/roll-artist/scripts/roll-artist.py deleted 
file mode 100644
index c3bc1fd0..00000000
--- a/extensions-builtin/roll-artist/scripts/roll-artist.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import random
-
-from modules import script_callbacks, shared
-import gradio as gr
-
-art_symbol = '\U0001f3a8' # 🎨
-global_prompt = None
-related_ids = {"txt2img_prompt", "txt2img_clear_prompt", "img2img_prompt", "img2img_clear_prompt" }
-
-
-def roll_artist(prompt):
-    allowed_cats = set([x for x in shared.artist_db.categories() if len(shared.opts.random_artist_categories)==0 or x in shared.opts.random_artist_categories])
-    artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
-
-    return prompt + ", " + artist.name if prompt != '' else artist.name
-
-
-def add_roll_button(prompt):
-    roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
-
-    roll.click(
-        fn=roll_artist,
-        _js="update_txt2img_tokens",
-        inputs=[
-            prompt,
-        ],
-        outputs=[
-            prompt,
-        ]
-    )
-
-
-def after_component(component, **kwargs):
-    global global_prompt
-
-    elem_id = kwargs.get('elem_id', None)
-    if elem_id not in related_ids:
-        return
-
-    if elem_id == "txt2img_prompt":
-        global_prompt = component
-    elif elem_id == "txt2img_clear_prompt":
-        add_roll_button(global_prompt)
-    elif elem_id == "img2img_prompt":
-        global_prompt = component
-    elif elem_id == "img2img_clear_prompt":
-        add_roll_button(global_prompt)
-
-
-script_callbacks.on_after_component(after_component)
diff --git a/javascript/hints.js b/javascript/hints.js
index f4079f96..ef410fba 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -14,7 +14,6 @@ titles = {
     "Seed": "A value that determines the output of random number generator - if you create an image with same parameters and seed as another image, you'll get the same result",
     "\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
     "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
-    "\u{1f3a8}": "Add a random artist to the prompt.",
     "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
     "\u{1f4c2}": "Open images output directory",
     "\u{1f4be}": "Save style",
diff --git a/modules/api/api.py b/modules/api/api.py
index 2c371e6e..f2e9e884 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -126,8 +126,6 @@ class Api:
         self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
         self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
         self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
-        self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
-        self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
         self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
         self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
         self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
@@ -390,12 +388,6 @@ class Api:
 
         return styleList
 
-    def get_artists_categories(self):
-        return shared.artist_db.cats
-
-    def get_artists(self):
-        return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
-
     def get_embeddings(self):
         db = sd_hijack.model_hijack.embedding_db
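For API consumers: the api.py hunk above deletes the /sdapi/v1/artist-categories and /sdapi/v1/artists routes outright, so external scripts that call them start getting 404 responses on builds containing this commit. A minimal client-side probe, a sketch only, assuming nothing beyond the standard library and a webui instance at the default local address:

# Sketch only: detect whether a running webui still serves the removed
# artists endpoint. The base URL is an assumption, not taken from the patch.
import urllib.error
import urllib.request

BASE_URL = "http://127.0.0.1:7860"  # assumed default local webui address

def artists_endpoint_available(base_url=BASE_URL):
    try:
        with urllib.request.urlopen(base_url + "/sdapi/v1/artists") as resp:
            return resp.status == 200
    except urllib.error.HTTPError:
        return False  # builds containing this commit answer 404 here
    except urllib.error.URLError:
        return False  # no server listening at that address

if __name__ == "__main__":
    print("artists endpoint available:", artists_endpoint_available())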
diff --git a/modules/artists.py b/modules/artists.py
deleted file mode 100644
index 3612758b..00000000
--- a/modules/artists.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import os.path
-import csv
-from collections import namedtuple
-
-Artist = namedtuple("Artist", ['name', 'weight', 'category'])
-
-
-class ArtistsDatabase:
-    def __init__(self, filename):
-        self.cats = set()
-        self.artists = []
-
-        if not os.path.exists(filename):
-            return
-
-        with open(filename, "r", newline='', encoding="utf8") as file:
-            reader = csv.DictReader(file)
-
-            for row in reader:
-                artist = Artist(row["artist"], float(row["score"]), row["category"])
-                self.artists.append(artist)
-                self.cats.add(artist.category)
-
-    def categories(self):
-        return sorted(self.cats)
diff --git a/modules/interrogate.py b/modules/interrogate.py
index 738d8ff7..19938cbb 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -5,12 +5,13 @@ from collections import namedtuple
 import re
 
 import torch
+import torch.hub
 
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 
 import modules.shared as shared
-from modules import devices, paths, lowvram, modelloader
+from modules import devices, paths, lowvram, modelloader, errors
 
 blip_image_eval_size = 384
 clip_model_name = 'ViT-L/14'
@@ -20,27 +21,59 @@ Category = namedtuple("Category", ["name", "topn", "items"])
 
 re_topn = re.compile(r"\.top(\d+)\.")
 
+def download_default_clip_interrogate_categories(content_dir):
+    print("Downloading CLIP categories...")
+
+    tmpdir = content_dir + "_tmp"
+    try:
+        os.makedirs(tmpdir)
+
+        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt"))
+        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
+        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
+        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))
+
+        os.rename(tmpdir, content_dir)
+
+    except Exception as e:
+        errors.display(e, "downloading default CLIP interrogate categories")
+    finally:
+        if os.path.exists(tmpdir):
+            os.remove(tmpdir)
+
+
 class InterrogateModels:
     blip_model = None
     clip_model = None
     clip_preprocess = None
-    categories = None
     dtype = None
     running_on_cpu = None
 
     def __init__(self, content_dir):
-        self.categories = []
+        self.loaded_categories = None
+        self.content_dir = content_dir
         self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
 
-        if os.path.exists(content_dir):
-            for filename in os.listdir(content_dir):
+    def categories(self):
+        if self.loaded_categories is not None:
+            return self.loaded_categories
+
+        self.loaded_categories = []
+
+        if not os.path.exists(self.content_dir):
+            download_default_clip_interrogate_categories(self.content_dir)
+
+        if os.path.exists(self.content_dir):
+            for filename in os.listdir(self.content_dir):
                 m = re_topn.search(filename)
                 topn = 1 if m is None else int(m.group(1))
 
-                with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
+                with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
                     lines = [x.strip() for x in file.readlines()]
 
-                self.categories.append(Category(name=filename, topn=topn, items=lines))
+                self.loaded_categories.append(Category(name=filename, topn=topn, items=lines))
+
+        return self.loaded_categories
 
     def load_blip_model(self):
         import models.blip
@@ -139,7 +172,6 @@ class InterrogateModels:
         shared.state.begin()
         shared.state.job = 'interrogate'
         try:
-
             if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                 lowvram.send_everything_to_cpu()
                 devices.torch_gc()
@@ -159,12 +191,7 @@ class InterrogateModels:
 
                 image_features /= image_features.norm(dim=-1, keepdim=True)
 
-                if shared.opts.interrogate_use_builtin_artists:
-                    artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]
-
-                    res += ", " + artist[0]
-
-                for name, topn, items in self.categories:
+                for name, topn, items in self.categories():
                     matches = self.rank(image_features, items, top_count=topn)
                     for match, score in matches:
                         if shared.opts.interrogate_return_ranks:
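The download helper in the interrogate.py hunk fetches everything into content_dir + "_tmp" and only renames the directory into place once every file has arrived, so an interrupted download can never be mistaken for a valid category directory on the next startup. One caveat worth noting: the finally block calls os.remove() on the temporary directory, which raises on a non-empty directory; shutil.rmtree() is what actually clears a partially filled tmpdir. A self-contained sketch of the same download-then-rename pattern, with hypothetical URLs and names, nothing here is taken from the webui codebase:

# Sketch of the atomic download-into-tmpdir-then-rename pattern; the
# function name and arguments are illustrative assumptions.
import os
import shutil
import urllib.request

def download_atomically(urls, content_dir):
    tmpdir = content_dir + "_tmp"
    try:
        os.makedirs(tmpdir)
        for url in urls:
            # Download every file into the temporary directory first.
            urllib.request.urlretrieve(url, os.path.join(tmpdir, os.path.basename(url)))
        # The final name only ever appears once the directory is complete.
        os.rename(tmpdir, content_dir)
    finally:
        # Anything left under the temporary name is a failed partial download.
        if os.path.exists(tmpdir):
            shutil.rmtree(tmpdir)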
filename), "r", encoding="utf8") as file: lines = [x.strip() for x in file.readlines()] - self.categories.append(Category(name=filename, topn=topn, items=lines)) + self.loaded_categories.append(Category(name=filename, topn=topn, items=lines)) + + return self.loaded_categories def load_blip_model(self): import models.blip @@ -139,7 +172,6 @@ class InterrogateModels: shared.state.begin() shared.state.job = 'interrogate' try: - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() devices.torch_gc() @@ -159,12 +191,7 @@ class InterrogateModels: image_features /= image_features.norm(dim=-1, keepdim=True) - if shared.opts.interrogate_use_builtin_artists: - artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0] - - res += ", " + artist[0] - - for name, topn, items in self.categories: + for name, topn, items in self.categories(): matches = self.rank(image_features, items, top_count=topn) for match, score in matches: if shared.opts.interrogate_return_ranks: diff --git a/modules/shared.py b/modules/shared.py index c0e11f18..72fb1934 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -9,7 +9,6 @@ from PIL import Image import gradio as gr import tqdm -import modules.artists import modules.interrogate import modules.memmon import modules.styles @@ -254,8 +253,6 @@ class State: state = State() state.server_start = time.time() -artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv')) - styles_filename = cmd_opts.styles_file prompt_styles = modules.styles.StyleDatabase(styles_filename) @@ -408,7 +405,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), - "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) options_templates.update(options_section(('compatibility', "Compatibility"), { @@ -419,7 +415,6 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { options_templates.update(options_section(('interrogate', "Interrogate Options"), { "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), - "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"), "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."), "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), diff --git a/modules/ui.py b/modules/ui.py index d23b2b8e..164e0e93 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -228,17 +228,17 @@ def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_di left, _ = os.path.splitext(filename) print(interrogation_function(img), 
file=open(os.path.join(ii_output_dir, left + ".txt"), 'a')) - return [gr_show(True), None] + return [gr.update(), None] def interrogate(image): prompt = shared.interrogator.interrogate(image.convert("RGB")) - return gr_show(True) if prompt is None else prompt + return gr.update() if prompt is None else prompt def interrogate_deepbooru(image): prompt = deepbooru.model.tag(image) - return gr_show(True) if prompt is None else prompt + return gr.update() if prompt is None else prompt def create_seed_inputs(target_interface): @@ -1039,19 +1039,18 @@ def create_ui(): init_img_inpaint, ], outputs=[img2img_prompt, dummy_component], - show_progress=False, ) img2img_prompt.submit(**img2img_args) submit.click(**img2img_args) img2img_interrogate.click( - fn=lambda *args : process_interrogate(interrogate, *args), + fn=lambda *args: process_interrogate(interrogate, *args), **interrogate_args, ) img2img_deepbooru.click( - fn=lambda *args : process_interrogate(interrogate_deepbooru, *args), + fn=lambda *args: process_interrogate(interrogate_deepbooru, *args), **interrogate_args, ) -- cgit v1.2.3 From 184e23eb89c198b42f351a4d5ff862ee64917619 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 09:48:38 +0300 Subject: relocate tool buttons next to generate button prevent extra network tabs from putting images into wrong prompts prevent settings leaking into prompt --- html/extra-networks-card.html | 2 +- javascript/extraNetworks.js | 33 +++++++++++++-------------------- javascript/ui.js | 4 ++-- modules/ui.py | 43 +++++++++++++++++++++---------------------- style.css | 18 +++++++----------- 5 files changed, 44 insertions(+), 56 deletions(-) diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index 7314b063..1bdf1d27 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,4 +1,4 @@ -
    +
      diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 71e522d1..5e0d9714 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -6,49 +6,42 @@ function setupExtraNetworksForTab(tabname){ gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_close')) } -var activePromptTextarea = null; -var activePositivePromptTextarea = null; +var activePromptTextarea = {}; function setupExtraNetworks(){ setupExtraNetworksForTab('txt2img') setupExtraNetworksForTab('img2img') - function registerPrompt(id, isNegative){ + function registerPrompt(tabname, id){ var textarea = gradioApp().querySelector("#" + id + " > label > textarea"); - if (activePromptTextarea == null){ - activePromptTextarea = textarea - } - if (activePositivePromptTextarea == null && ! isNegative){ - activePositivePromptTextarea = textarea + if (! activePromptTextarea[tabname]){ + activePromptTextarea[tabname] = textarea } textarea.addEventListener("focus", function(){ - activePromptTextarea = textarea; - if(! isNegative) activePositivePromptTextarea = textarea; + activePromptTextarea[tabname] = textarea; }); } - registerPrompt('txt2img_prompt') - registerPrompt('txt2img_neg_prompt', true) - registerPrompt('img2img_prompt') - registerPrompt('img2img_neg_prompt', true) + registerPrompt('txt2img', 'txt2img_prompt') + registerPrompt('txt2img', 'txt2img_neg_prompt') + registerPrompt('img2img', 'img2img_prompt') + registerPrompt('img2img', 'img2img_neg_prompt') } onUiLoaded(setupExtraNetworks) -function cardClicked(textToAdd, allowNegativePrompt){ - textarea = allowNegativePrompt ? activePromptTextarea : activePositivePromptTextarea +function cardClicked(tabname, textToAdd, allowNegativePrompt){ + var textarea = allowNegativePrompt ? 
activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea") textarea.value = textarea.value + " " + textToAdd updateInput(textarea) - - return false } function saveCardPreview(event, tabname, filename){ - textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea') - button = gradioApp().getElementById(tabname + '_save_preview') + var textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea') + var button = gradioApp().getElementById(tabname + '_save_preview') textarea.value = filename updateInput(textarea) diff --git a/javascript/ui.js b/javascript/ui.js index a7e75439..77256e15 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -203,8 +203,8 @@ onUiUpdate(function(){ json_elem = gradioApp().getElementById('settings_json') if(json_elem == null) return; - textarea = json_elem.querySelector('textarea') - jsdata = textarea.value + var textarea = json_elem.querySelector('textarea') + var jsdata = textarea.value opts = JSON.parse(jsdata) executeCallbacks(optionsChangedCallbacks); diff --git a/modules/ui.py b/modules/ui.py index 164e0e93..fbc3efa0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -349,30 +349,13 @@ def create_toprow(is_img2img): with gr.Row(): with gr.Column(scale=80): with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)") + prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)") with gr.Row(): with gr.Column(scale=80): with gr.Row(): negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)") - with gr.Column(scale=1, elem_id="roll_col"): - paste = ToolButton(value=paste_symbol, elem_id="paste") - clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") - extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks") - - token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") - token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - negative_token_counter = gr.HTML(value="", elem_id=f"{id_part}_negative_token_counter") - negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button") - - clear_prompt_button.click( - fn=lambda *x: x, - _js="confirm_clear_prompt", - inputs=[prompt, negative_prompt], - outputs=[prompt, negative_prompt], - ) - button_interrogate = None button_deepbooru = None if is_img2img: @@ -380,7 +363,7 @@ def create_toprow(is_img2img): button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") - with gr.Column(scale=1): + with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"): with gr.Row(elem_id=f"{id_part}_generate_box"): interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") skip = gr.Button('Skip', elem_id=f"{id_part}_skip") @@ -398,13 +381,29 @@ def create_toprow(is_img2img): outputs=[], ) + with gr.Row(elem_id=f"{id_part}_tools"): + paste = ToolButton(value=paste_symbol, elem_id="paste") + clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + extra_networks_button = ToolButton(value=extra_networks_symbol, 
elem_id=f"{id_part}_extra_networks") + prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply") + save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create") + + token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") + token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") + negative_token_counter = gr.HTML(value="", elem_id=f"{id_part}_negative_token_counter") + negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button") + + clear_prompt_button.click( + fn=lambda *x: x, + _js="confirm_clear_prompt", + inputs=[prompt, negative_prompt], + outputs=[prompt, negative_prompt], + ) + with gr.Row(elem_id=f"{id_part}_styles_row"): prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True) create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles") - prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id="style_apply") - save_style = ToolButton(value=save_style_symbol, elem_id="style_create") - return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button diff --git a/style.css b/style.css index 5e8bc2ca..04bf2982 100644 --- a/style.css +++ b/style.css @@ -124,15 +124,12 @@ height: 100%; } -#roll_col{ - min-width: unset !important; - flex-grow: 0 !important; - padding: 0 1em 0 0; +#txt2img_actions_column, #img2img_actions_column{ gap: 0; } -#roll_col > button { - margin: 0.1em 0; +#txt2img_tools, #img2img_tools{ + gap: 0.4em; } #interrogate_col{ @@ -153,7 +150,6 @@ #txt2img_styles_row, #img2img_styles_row{ gap: 0.25em; - margin-top: 0.5em; } #txt2img_styles_row > button, #img2img_styles_row > button{ @@ -164,6 +160,10 @@ padding: 0; } +#txt2img_styles > label > div, #img2img_styles > label > div{ + min-height: 3.2em; +} + #txt2img_styles ul, #img2img_styles ul{ max-height: 35em; z-index: 2000; @@ -770,10 +770,6 @@ footer { line-height: 2.4em; } -#txt2img_extra_networks, #img2img_extra_networks{ - margin-top: -1em; -} - .extra-networks > div > [id *= '_extra_']{ margin: 0.3em; } -- cgit v1.2.3 From cbfb4632585415dc914aff8c44869d792fd64c24 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 11:22:16 +0300 Subject: fix failing tests by removing then :^) --- test/basic_features/utils_test.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/test/basic_features/utils_test.py b/test/basic_features/utils_test.py index 94e00253..0bfc28a0 100644 --- a/test/basic_features/utils_test.py +++ b/test/basic_features/utils_test.py @@ -12,8 +12,6 @@ class UtilsTests(unittest.TestCase): self.url_face_restorers = "http://localhost:7860/sdapi/v1/face-restorers" self.url_realesrgan_models = "http://localhost:7860/sdapi/v1/realesrgan-models" self.url_prompt_styles = "http://localhost:7860/sdapi/v1/prompt-styles" - self.url_artist_categories = "http://localhost:7860/sdapi/v1/artist-categories" - self.url_artists = "http://localhost:7860/sdapi/v1/artists" self.url_embeddings = "http://localhost:7860/sdapi/v1/embeddings" def test_options_get(self): @@ -56,15 +54,9 @@ class UtilsTests(unittest.TestCase): def test_prompt_styles(self): 
self.assertEqual(requests.get(self.url_prompt_styles).status_code, 200) - - def test_artist_categories(self): - self.assertEqual(requests.get(self.url_artist_categories).status_code, 200) - - def test_artists(self): - self.assertEqual(requests.get(self.url_artists).status_code, 200) def test_embeddings(self): - self.assertEqual(requests.get(self.url_artists).status_code, 200) + self.assertEqual(requests.get(self.url_embeddings).status_code, 200) if __name__ == "__main__": unittest.main() -- cgit v1.2.3 From 3262e825cc542ff634e6ba2e3a162eafdc6c1bba Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 21 Jan 2023 17:42:04 +0900 Subject: add --xformers-flash-attention option & impl --- modules/sd_hijack_optimizations.py | 26 ++++++++++++++++++++++++-- modules/shared.py | 1 + 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 4fa54329..9967359b 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -290,7 +290,19 @@ def xformers_attention_forward(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) + + if shared.cmd_opts.xformers_flash_attention: + op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp + fw, bw = op + if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)): + # print('xformers_attention_forward', q.shape, k.shape, v.shape) + # Flash Attention is not available for the input arguments. + # Fall back to default xFormers' backend. + op = None + else: + op = None + + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=op) out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) @@ -365,7 +377,17 @@ def xformers_attnblock_forward(self, x): q = q.contiguous() k = k.contiguous() v = v.contiguous() - out = xformers.ops.memory_efficient_attention(q, k, v) + if shared.cmd_opts.xformers_flash_attention: + op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp + fw, bw = op + if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v)): + # print('xformers_attnblock_forward', q.shape, k.shape, v.shape) + # Flash Attention is not available for the input arguments. + # Fall back to default xFormers' backend. 
+ op = None + else: + op = None + out = xformers.ops.memory_efficient_attention(q, k, v, op=op) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) return x + out diff --git a/modules/shared.py b/modules/shared.py index 72fb1934..23328adf 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -57,6 +57,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None) parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") +parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)") parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.") parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization") -- cgit v1.2.3 From 855b9e3d1c5a1bd8c2d815d38a38bc7c410be5a8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 16:15:53 +0300 Subject: Lora support! update readme to reflect some recent changes --- README.md | 14 +- extensions-builtin/Lora/extra_networks_lora.py | 20 +++ extensions-builtin/Lora/lora.py | 198 ++++++++++++++++++++++ extensions-builtin/Lora/scripts/lora_script.py | 30 ++++ extensions-builtin/Lora/ui_extra_networks_lora.py | 35 ++++ modules/extra_networks_hypernet.py | 2 +- modules/script_callbacks.py | 15 ++ modules/ui_extra_networks.py | 2 +- webui.py | 2 + 9 files changed, 314 insertions(+), 4 deletions(-) create mode 100644 extensions-builtin/Lora/extra_networks_lora.py create mode 100644 extensions-builtin/Lora/lora.py create mode 100644 extensions-builtin/Lora/scripts/lora_script.py create mode 100644 extensions-builtin/Lora/ui_extra_networks_lora.py diff --git a/README.md b/README.md index 1ac794e8..9c0cd1ef 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,7 @@ A browser interface based on Gradio library for Stable Diffusion. - Possible to change defaults/min/max/step values for UI elements via text config - Tiling support, a checkbox to create images that can be tiled like textures - Progress bar and live image generation preview + - Can use a separate neural network to produce previews with almost no VRAM or compute requirement - Negative prompt, an extra text field that allows you to list what you don't want to see in generated image - Styles, a way to save part of prompt and easily apply them via dropdown later - Variations, a way to generate same image but with tiny differences @@ -75,13 +76,22 @@ A browser interface based on Gradio library for Stable Diffusion. 
- hypernetworks and embeddings options - Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime) - Clip skip -- Use Hypernetworks -- Use VAEs +- Hypernetworks +- Loras (same as Hypernetworks but more pretty) +- A separate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt. +- Can select to load a different VAE from settings screen - Estimated completion time in progress bar - API - Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. - via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients)) - [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions +- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions +- Now without any bad letters! +- Load checkpoints in safetensors format +- Eased resolution restriction: generated image's dimension must be a multiple of 8 rather than 64 +- Now with a license! +- Reorder elements in the UI from settings screen +- ## Installation and Running Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. 
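The Lora implementation that follows patches torch.nn.Linear and torch.nn.Conv2d so that every forward pass adds a low-rank residual scaled by a per-network multiplier (see lora_forward further down). A minimal, self-contained sketch of that computation, where the rank, the shapes and the 0.8 multiplier are illustrative assumptions rather than values taken from the patch:

    import torch

    x = torch.randn(1, 768)
    base = torch.nn.Linear(768, 768, bias=False)  # the layer being patched
    down = torch.nn.Linear(768, 8, bias=False)    # "lora_down.weight": project to a small rank
    up = torch.nn.Linear(8, 768, bias=False)      # "lora_up.weight": project back up
    multiplier = 0.8

    # what the patched forward computes: the original output plus the scaled low-rank update
    res = base(x) + up(down(x)) * multiplier
    print(res.shape)  # torch.Size([1, 768])

In the prompt itself, an activation such as <lora:filename:0.8> (the file name here is hypothetical) is parsed by the extra networks machinery, and ExtraNetworkLora.activate ends up calling lora.load_loras(["filename"], [0.8]) before sampling starts.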
diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py new file mode 100644 index 00000000..8f2e753e --- /dev/null +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -0,0 +1,20 @@ +from modules import extra_networks +import lora + +class ExtraNetworkLora(extra_networks.ExtraNetwork): + def __init__(self): + super().__init__('lora') + + def activate(self, p, params_list): + names = [] + multipliers = [] + for params in params_list: + assert len(params.items) > 0 + + names.append(params.items[0]) + multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) + + lora.load_loras(names, multipliers) + + def deactivate(self, p): + pass diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py new file mode 100644 index 00000000..7a3ad9a9 --- /dev/null +++ b/extensions-builtin/Lora/lora.py @@ -0,0 +1,198 @@ +import glob +import os +import re +import torch + +from modules import shared, devices, sd_models + +re_digits = re.compile(r"\d+") +re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)") +re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)") +re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)") +re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)") + + +def convert_diffusers_name_to_compvis(key): + def match(match_list, regex): + r = re.match(regex, key) + if not r: + return False + + match_list.clear() + match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) + return True + + m = [] + + if match(m, re_unet_down_blocks): + return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}" + + if match(m, re_unet_mid_blocks): + return f"diffusion_model_middle_block_1_{m[1]}" + + if match(m, re_unet_up_blocks): + return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}" + + if match(m, re_text_block): + return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" + + return key + + +class LoraOnDisk: + def __init__(self, name, filename): + self.name = name + self.filename = filename + + +class LoraModule: + def __init__(self, name): + self.name = name + self.multiplier = 1.0 + self.modules = {} + self.mtime = None + + +class LoraUpDownModule: + def __init__(self): + self.up = None + self.down = None + + +def assign_lora_names_to_compvis_modules(sd_model): + lora_layer_mapping = {} + + for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): + lora_name = name.replace(".", "_") + lora_layer_mapping[lora_name] = module + module.lora_layer_name = lora_name + + for name, module in shared.sd_model.model.named_modules(): + lora_name = name.replace(".", "_") + lora_layer_mapping[lora_name] = module + module.lora_layer_name = lora_name + + sd_model.lora_layer_mapping = lora_layer_mapping + + +def load_lora(name, filename): + lora = LoraModule(name) + lora.mtime = os.path.getmtime(filename) + + sd = sd_models.read_state_dict(filename) + + keys_failed_to_match = [] + + for key_diffusers, weight in sd.items(): + fullkey = convert_diffusers_name_to_compvis(key_diffusers) + key, lora_key = fullkey.split(".", 1) + + sd_module = shared.sd_model.lora_layer_mapping.get(key, None) + if sd_module is None: + keys_failed_to_match.append(key_diffusers) + continue + + if type(sd_module) == torch.nn.Linear: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(sd_module) == torch.nn.Conv2d: + module = 
torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + else: + assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}' + + with torch.no_grad(): + module.weight.copy_(weight) + + module.to(device=devices.device, dtype=devices.dtype) + + lora_module = lora.modules.get(key, None) + if lora_module is None: + lora_module = LoraUpDownModule() + lora.modules[key] = lora_module + + if lora_key == "lora_up.weight": + lora_module.up = module + elif lora_key == "lora_down.weight": + lora_module.down = module + else: + assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight or lora_down.weight' + + if len(keys_failed_to_match) > 0: + print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}") + + return lora + + +def load_loras(names, multipliers=None): + already_loaded = {} + + for lora in loaded_loras: + if lora.name in names: + already_loaded[lora.name] = lora + + loaded_loras.clear() + + loras_on_disk = [available_loras.get(name, None) for name in names] + if any([x is None for x in loras_on_disk]): + list_available_loras() + + loras_on_disk = [available_loras.get(name, None) for name in names] + + for i, name in enumerate(names): + lora = already_loaded.get(name, None) + + lora_on_disk = loras_on_disk[i] + if lora_on_disk is not None: + if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime: + lora = load_lora(name, lora_on_disk.filename) + + if lora is None: + print(f"Couldn't find Lora with name {name}") + continue + + lora.multiplier = multipliers[i] if multipliers else 1.0 + loaded_loras.append(lora) + + +def lora_forward(module, input, res): + if len(loaded_loras) == 0: + return res + + lora_layer_name = getattr(module, 'lora_layer_name', None) + for lora in loaded_loras: + module = lora.modules.get(lora_layer_name, None) + if module is not None: + res = res + module.up(module.down(input)) * lora.multiplier + + return res + + +def lora_Linear_forward(self, input): + return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input)) + + +def lora_Conv2d_forward(self, input): + return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input)) + + +def list_available_loras(): + available_loras.clear() + + os.makedirs(lora_dir, exist_ok=True) + + candidates = glob.glob(os.path.join(lora_dir, '**/*.pt'), recursive=True) + glob.glob(os.path.join(lora_dir, '**/*.safetensors'), recursive=True) + + for filename in sorted(candidates): + if os.path.isdir(filename): + continue + + name = os.path.splitext(os.path.basename(filename))[0] + + available_loras[name] = LoraOnDisk(name, filename) + + +lora_dir = os.path.join(shared.models_path, "Lora") +available_loras = {} +loaded_loras = [] + +list_available_loras() + diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py new file mode 100644 index 00000000..60b9eb64 --- /dev/null +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -0,0 +1,30 @@ +import torch + +import lora +import extra_networks_lora +import ui_extra_networks_lora +from modules import script_callbacks, ui_extra_networks, extra_networks + + +def unload(): + torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora + torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora + + +def before_ui(): + ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora()) + extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora()) + + 
+if not hasattr(torch.nn, 'Linear_forward_before_lora'): + torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward + +if not hasattr(torch.nn, 'Conv2d_forward_before_lora'): + torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward + +torch.nn.Linear.forward = lora.lora_Linear_forward +torch.nn.Conv2d.forward = lora.lora_Conv2d_forward + +script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) +script_callbacks.on_script_unloaded(unload) +script_callbacks.on_before_ui(before_ui) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py new file mode 100644 index 00000000..65397890 --- /dev/null +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -0,0 +1,35 @@ +import os +import lora + +from modules import shared, ui_extra_networks + + +class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Lora') + + def refresh(self): + lora.list_available_loras() + + def list_items(self): + for name, lora_on_disk in lora.available_loras.items(): + path, ext = os.path.splitext(lora_on_disk.filename) + previews = [path + ".png", path + ".preview.png"] + + preview = None + for file in previews: + if os.path.isfile(file): + preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file)) + break + + yield { + "name": name, + "filename": path, + "preview": preview, + "prompt": f"", + "local_preview": path + ".png", + } + + def allowed_directories_for_previews(self): + return [lora.lora_dir] + diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py index 6a0c4ba8..ff279a1f 100644 --- a/modules/extra_networks_hypernet.py +++ b/modules/extra_networks_hypernet.py @@ -17,5 +17,5 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork): hypernetwork.load_hypernetworks(names, multipliers) - def deactivate(p, self): + def deactivate(self, p): pass diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index a9e19236..4bb45ec7 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -73,6 +73,7 @@ callback_map = dict( callbacks_image_grid=[], callbacks_infotext_pasted=[], callbacks_script_unloaded=[], + callbacks_before_ui=[], ) @@ -189,6 +190,14 @@ def script_unloaded_callback(): report_exception(c, 'script_unloaded') +def before_ui_callback(): + for c in reversed(callback_map['callbacks_before_ui']): + try: + c.callback() + except Exception: + report_exception(c, 'before_ui') + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if len(stack) > 0 else 'unknown file' @@ -313,3 +322,9 @@ def on_script_unloaded(callback): the script did should be reverted here""" add_callback(callback_map['callbacks_script_unloaded'], callback) + + +def on_before_ui(callback): + """register a function to be called before the UI is created.""" + + add_callback(callback_map['callbacks_before_ui'], callback) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 253e90f7..796e879c 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -10,7 +10,7 @@ extra_pages = [] def register_page(page): - """registers extra networks page for the UI; recommend doing it in on_app_started() callback for extensions""" + """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions""" extra_pages.append(page) diff --git a/webui.py b/webui.py index 
e8dd822a..88d04840 100644 --- a/webui.py +++ b/webui.py @@ -165,6 +165,8 @@ def webui(): if shared.opts.clean_temp_dir_at_start: ui_tempdir.cleanup_tmpdr() + modules.script_callbacks.before_ui_callback() + shared.demo = modules.ui.create_ui() app, local_url, share_url = shared.demo.launch( -- cgit v1.2.3 From 92fb1096dbf6403e109a8eb7bc5d18ce487ae9b5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 16:41:25 +0300 Subject: make it so that extra networks are not removed from infotext --- modules/processing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index b5deeacf..241961ac 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -561,7 +561,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: cache[0] = (required_prompts, steps) return cache[1] - p.all_prompts, extra_network_data = extra_networks.parse_prompts(p.all_prompts) + _, extra_network_data = extra_networks.parse_prompts(p.all_prompts[0:1]) with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): @@ -593,6 +593,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if len(prompts) == 0: break + prompts, _ = extra_networks.parse_prompts(prompts) + if p.scripts is not None: p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds) -- cgit v1.2.3 From 424cefe11878c9c7d2663381441e7efe62532180 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 17:20:24 +0300 Subject: add search box to extra networks --- javascript/extraNetworks.js | 20 ++++++++++++++++++-- modules/ui_extra_networks.py | 14 ++++++++++---- style.css | 8 ++++++++ 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 5e0d9714..54ded58c 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -2,8 +2,24 @@ function setupExtraNetworksForTab(tabname){ gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks') - gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_refresh')) - gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_close')) + var tabs = gradioApp().querySelector('#'+tabname+'_extra_tabs > div') + var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea') + var refresh = gradioApp().getElementById(tabname+'_extra_refresh') + var close = gradioApp().getElementById(tabname+'_extra_close') + + search.classList.add('search') + tabs.appendChild(search) + tabs.appendChild(refresh) + tabs.appendChild(close) + + search.addEventListener("input", function(evt){ + searchTerm = search.value + + gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){ + text = elem.querySelector('.name').textContent + elem.style.display = text.indexOf(searchTerm) == -1 ? 
"none" : "" + }) + }); } var activePromptTextarea = {}; diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 796e879c..e2e060c8 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -18,6 +18,7 @@ def register_page(page): class ExtraNetworksPage: def __init__(self, title): self.title = title + self.name = title.lower() self.card_page = shared.html("extra-networks-card.html") self.allow_negative_prompt = False @@ -34,7 +35,11 @@ class ExtraNetworksPage: dirs = "".join([f"
<li>{x}</li>" for x in self.allowed_directories_for_previews()]) items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) - res = "<div class='extra-network-cards'>" + items_html + "</div>" + res = f""" +<div id='{tabname}_{self.name}_cards' class='extra-network-cards'> +{items_html} +</div>
      +""" return res @@ -81,14 +86,15 @@ def create_ui(container, button, tabname): ui.tabname = tabname with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs: - button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") - button_close = gr.Button('Close', elem_id=tabname+"_extra_close") - for page in ui.stored_extra_pages: with gr.Tab(page.title): page_elem = gr.HTML(page.create_html(ui.tabname)) ui.pages.append(page_elem) + filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) + button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") + button_close = gr.Button('Close', elem_id=tabname+"_extra_close") + ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) diff --git a/style.css b/style.css index 04bf2982..1e59575f 100644 --- a/style.css +++ b/style.css @@ -774,6 +774,14 @@ footer { margin: 0.3em; } + + +#txt2img_extra_networks .search, #img2img_extra_networks .search{ + display: inline-block; + max-width: 16em; + margin: 0.3em; +} + .extra-network-cards .nocards{ margin: 1.25em 0.5em 0.5em 0.5em; } -- cgit v1.2.3 From 63b824376c49013880ff44c260ea426e2899511e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 18:47:54 +0300 Subject: add --gradio-queue option to enable gradio queue --- modules/shared.py | 2 ++ webui.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/modules/shared.py b/modules/shared.py index 72fb1934..52bbb807 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -100,6 +100,8 @@ parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS o parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None) parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) +parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button") + script_loading.preload_extensions(extensions.extensions_dir, parser) script_loading.preload_extensions(extensions.extensions_builtin_dir, parser) diff --git a/webui.py b/webui.py index 88d04840..d235da74 100644 --- a/webui.py +++ b/webui.py @@ -169,6 +169,9 @@ def webui(): shared.demo = modules.ui.create_ui() + if cmd_opts.gradio_queue: + shared.demo.queue(64) + app, local_url, share_url = shared.demo.launch( share=cmd_opts.share, server_name=server_name, -- cgit v1.2.3 From a2749ec655af93d96ea7ebed85e8c1e7c5072b02 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 18:52:45 +0300 Subject: load Lora from .ckpt also --- extensions-builtin/Lora/lora.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 7a3ad9a9..6d860224 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -179,7 +179,10 @@ def list_available_loras(): os.makedirs(lora_dir, exist_ok=True) - candidates = glob.glob(os.path.join(lora_dir, '**/*.pt'), recursive=True) + glob.glob(os.path.join(lora_dir, '**/*.safetensors'), recursive=True) + candidates = \ + glob.glob(os.path.join(lora_dir, '**/*.pt'), recursive=True) + \ + 
glob.glob(os.path.join(lora_dir, '**/*.safetensors'), recursive=True) + \ + glob.glob(os.path.join(lora_dir, '**/*.ckpt'), recursive=True) for filename in sorted(candidates): if os.path.isdir(filename): @@ -195,4 +198,3 @@ available_loras = {} loaded_loras = [] list_available_loras() - -- cgit v1.2.3 From 3deea3413575db0ff71f20f4265f3bdc08e35453 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 19:36:08 +0300 Subject: extract extra network data from prompt earlier --- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 241961ac..6e6371a1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -532,6 +532,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: model_hijack.embedding_db.load_textual_inversion_embeddings() + _, extra_network_data = extra_networks.parse_prompts(p.all_prompts[0:1]) + if p.scripts is not None: p.scripts.process(p) @@ -561,8 +563,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: cache[0] = (required_prompts, steps) return cache[1] - _, extra_network_data = extra_networks.parse_prompts(p.all_prompts[0:1]) - with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) -- cgit v1.2.3 From f53527f7786575fe60da0223bd63ea3f0a06a754 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 20:07:14 +0300 Subject: make it run on gradio < 3.16.2 --- modules/ui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index fbc3efa0..b3105901 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1897,7 +1897,7 @@ def create_ui(): if type(x) == gr.Dropdown: def check_dropdown(val): - if x.multiselect: + if getattr(x, 'multiselect', False): return all([value in x.choices for value in val]) else: return val in x.choices -- cgit v1.2.3 From f726df8a2fd2620ba245de5702e2afe620f79b91 Mon Sep 17 00:00:00 2001 From: James Tolton Date: Sat, 21 Jan 2023 12:59:05 -0500 Subject: Compile and serve js from /statica instead of inline in html --- modules/ui.py | 35 ++++++++++++++++++++++++++++++----- statica/put-static-files-here.txt | 1 + webui.py | 2 ++ 3 files changed, 33 insertions(+), 5 deletions(-) create mode 100644 statica/put-static-files-here.txt diff --git a/modules/ui.py b/modules/ui.py index fbc3efa0..d19eaf25 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -10,6 +10,7 @@ import sys import tempfile import time import traceback +from collections import OrderedDict from functools import partial, reduce import warnings @@ -1918,27 +1919,51 @@ def create_ui(): def reload_javascript(): + javascript_files = OrderedDict() with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: - javascript = f'' + contents = jsfile.read() + javascript_files["script.js"] = [contents] + # javascript = f'' scripts_list = modules.scripts.list_scripts("javascript", ".js") for basedir, filename, path in scripts_list: with open(path, "r", encoding="utf8") as jsfile: - javascript += f"\n" + contents = jsfile.read() + javascript_files[filename] = [contents] + # javascript += f"\n" if cmd_opts.theme is not None: - javascript += f"\n\n" + javascript_files["theme.js"] = [f"set_theme('{cmd_opts.theme}');"] + # javascript += f"\n\n" - javascript += f"\n" + # javascript += f"\n" + 
javascript_files["localization.js"] = [f"{localization.localization_js(shared.opts.localization)}"] + + compiled_name = "webui-compiled.js" + head = f""" + + """ def template_response(*args, **kwargs): res = shared.GradioTemplateResponseOriginal(*args, **kwargs) res.body = res.body.replace( - b'', f'{javascript}'.encode("utf8")) + b'', f'{head}'.encode("utf8")) res.init_headers() return res + for k in javascript_files: + javascript_files[k] = "\n".join(javascript_files[k]) + + # make static_path if not exists + statica_path = os.path.join(script_path, 'statica') + if not os.path.exists(statica_path): + os.mkdir(statica_path) + + javascript_out = "\n\n\n".join([f"// \n\n{v}" for k, v in javascript_files.items()]) + with open(os.path.join(script_path, "statica", compiled_name), "w", encoding="utf8") as jsfile: + jsfile.write(javascript_out) + gradio.routes.templates.TemplateResponse = template_response diff --git a/statica/put-static-files-here.txt b/statica/put-static-files-here.txt new file mode 100644 index 00000000..7cfaaa86 --- /dev/null +++ b/statica/put-static-files-here.txt @@ -0,0 +1 @@ +ayo \ No newline at end of file diff --git a/webui.py b/webui.py index d235da74..50dee700 100644 --- a/webui.py +++ b/webui.py @@ -8,6 +8,7 @@ import re from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware +from starlette.staticfiles import StaticFiles from modules import import_hook, errors, extra_networks from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion @@ -195,6 +196,7 @@ def webui(): setup_cors(app) app.add_middleware(GZipMiddleware, minimum_size=1000) + app.mount("/statica", StaticFiles(directory=os.path.join(script_path, 'statica')), name="statica") modules.progress.setup_progress_api(app) -- cgit v1.2.3 From 17af0fb95574068a1d5032ae96879dab145e173a Mon Sep 17 00:00:00 2001 From: James Tolton Date: Sat, 21 Jan 2023 13:27:05 -0500 Subject: remove commented out lines --- modules/ui.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index d19eaf25..ef85d43c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1923,7 +1923,6 @@ def reload_javascript(): with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: contents = jsfile.read() javascript_files["script.js"] = [contents] - # javascript = f'' scripts_list = modules.scripts.list_scripts("javascript", ".js") @@ -1931,13 +1930,10 @@ def reload_javascript(): with open(path, "r", encoding="utf8") as jsfile: contents = jsfile.read() javascript_files[filename] = [contents] - # javascript += f"\n" if cmd_opts.theme is not None: javascript_files["theme.js"] = [f"set_theme('{cmd_opts.theme}');"] - # javascript += f"\n\n" - # javascript += f"\n" javascript_files["localization.js"] = [f"{localization.localization_js(shared.opts.localization)}"] compiled_name = "webui-compiled.js" -- cgit v1.2.3 From 50059ea661b63967b217e687819cf7a9081e4a0c Mon Sep 17 00:00:00 2001 From: James Tolton Date: Sat, 21 Jan 2023 14:07:48 -0500 Subject: server individually listed javascript files vs single compiled file --- modules/ui.py | 52 +++++++++++++++++---------------------- statica/put-static-files-here.txt | 1 - webui.py | 2 -- 3 files changed, 23 insertions(+), 32 deletions(-) delete mode 100644 statica/put-static-files-here.txt diff --git a/modules/ui.py b/modules/ui.py index ef85d43c..b372d29c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1919,47 +1919,41 @@ def 
create_ui(): def reload_javascript(): - javascript_files = OrderedDict() - with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: - contents = jsfile.read() - javascript_files["script.js"] = [contents] - scripts_list = modules.scripts.list_scripts("javascript", ".js") - + js_files = [] for basedir, filename, path in scripts_list: - with open(path, "r", encoding="utf8") as jsfile: - contents = jsfile.read() - javascript_files[filename] = [contents] + path = path[len(script_path) + 1:] + js_files.append(path) + inline = [f"{localization.localization_js(shared.opts.localization)};"] if cmd_opts.theme is not None: - javascript_files["theme.js"] = [f"set_theme('{cmd_opts.theme}');"] + inline.append(f"set_theme('{cmd_opts.theme}');", ) - javascript_files["localization.js"] = [f"{localization.localization_js(shared.opts.localization)}"] - - compiled_name = "webui-compiled.js" - head = f""" - - """ + t = int(time.time()) + head = [ + f""" + + """.strip() + ] + inline_code = "\n".join(inline) + head.append(f""" + + """.strip()) + for file in js_files: + head.append(f""" + + """.strip()) def template_response(*args, **kwargs): res = shared.GradioTemplateResponseOriginal(*args, **kwargs) + head_inject = "\n".join(head) res.body = res.body.replace( - b'', f'{head}'.encode("utf8")) + b'', f'{head_inject}'.encode("utf8")) res.init_headers() return res - for k in javascript_files: - javascript_files[k] = "\n".join(javascript_files[k]) - - # make static_path if not exists - statica_path = os.path.join(script_path, 'statica') - if not os.path.exists(statica_path): - os.mkdir(statica_path) - - javascript_out = "\n\n\n".join([f"// \n\n{v}" for k, v in javascript_files.items()]) - with open(os.path.join(script_path, "statica", compiled_name), "w", encoding="utf8") as jsfile: - jsfile.write(javascript_out) - gradio.routes.templates.TemplateResponse = template_response diff --git a/statica/put-static-files-here.txt b/statica/put-static-files-here.txt deleted file mode 100644 index 7cfaaa86..00000000 --- a/statica/put-static-files-here.txt +++ /dev/null @@ -1 +0,0 @@ -ayo \ No newline at end of file diff --git a/webui.py b/webui.py index 50dee700..d235da74 100644 --- a/webui.py +++ b/webui.py @@ -8,7 +8,6 @@ import re from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware -from starlette.staticfiles import StaticFiles from modules import import_hook, errors, extra_networks from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion @@ -196,7 +195,6 @@ def webui(): setup_cors(app) app.add_middleware(GZipMiddleware, minimum_size=1000) - app.mount("/statica", StaticFiles(directory=os.path.join(script_path, 'statica')), name="statica") modules.progress.setup_progress_api(app) -- cgit v1.2.3 From 035459c9a22bebcf68ac454a1f178fefe8c82054 Mon Sep 17 00:00:00 2001 From: James Tolton Date: Sat, 21 Jan 2023 14:11:13 -0500 Subject: remove dead import --- modules/ui.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modules/ui.py b/modules/ui.py index b372d29c..5fde7fc5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -10,7 +10,6 @@ import sys import tempfile import time import traceback -from collections import OrderedDict from functools import partial, reduce import warnings -- cgit v1.2.3 From 861fe750b01d5b6fa7434101d466b07a6f4b312e Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sat, 21 Jan 2023 14:26:07 -0500 Subject: fixes ui issues with checkbox and hires. 
sections --- style.css | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/style.css b/style.css index 1e59575f..507acec1 100644 --- a/style.css +++ b/style.css @@ -707,12 +707,16 @@ footer { #txt2img_checkboxes, #img2img_checkboxes{ margin-bottom: 0.5em; + margin-left: 0em; } #txt2img_checkboxes > div, #img2img_checkboxes > div{ flex: 0; white-space: nowrap; min-width: auto; } +#txt2img_hires_fix{ + margin-left: -0.8em; +} .inactive{ opacity: 0.5; -- cgit v1.2.3 From ac2eb97db90fe35cdea00d3fdd4680289259bd42 Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Sat, 21 Jan 2023 22:43:37 +0300 Subject: fix auto fill and repair separate axisoptions --- scripts/xy_grid.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index b1badec9..8ff315a7 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -165,10 +165,14 @@ class AxisOption: self.confirm = confirm self.cost = cost self.choices = choices - self.is_img2img = False class AxisOptionImg2Img(AxisOption): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_img2img = True + +class AxisOptionTxt2Img(AxisOption): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.is_img2img = False @@ -183,7 +187,8 @@ axis_options = [ AxisOption("CFG Scale", float, apply_field("cfg_scale")), AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value), AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), - AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), + AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), + AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]), AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), AxisOption("Sigma Churn", float, apply_field("s_churn")), AxisOption("Sigma min", float, apply_field("s_tmin")), @@ -192,8 +197,8 @@ axis_options = [ AxisOption("Eta", float, apply_field("eta")), AxisOption("Clip skip", int, apply_clip_skip), AxisOption("Denoising", float, apply_field("denoising_strength")), - AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [x.name for x in shared.sd_upscalers]), - AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")), + AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]), + AxisOptionImg2Img("Cond. 
Image Mask Weight", float, apply_field("inpainting_mask_weight")), AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)), AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), ] @@ -288,7 +293,7 @@ class Script(scripts.Script): return "X/Y plot" def ui(self, is_img2img): - current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img and is_img2img] + current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img] with gr.Row(): with gr.Column(scale=19): @@ -316,14 +321,14 @@ class Script(scripts.Script): swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args) def fill(x_type): - axis = axis_options[x_type] + axis = current_axis_options[x_type] return ", ".join(axis.choices()) if axis.choices else gr.update() fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values]) fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values]) def select_axis(x_type): - return gr.Button.update(visible=axis_options[x_type].choices is not None) + return gr.Button.update(visible=current_axis_options[x_type].choices is not None) x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button]) y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) -- cgit v1.2.3 From e4e0918f58d382b5da400e680d743dcf0e66fd7f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 22:57:19 +0300 Subject: remove timestamp for js files, reformat code --- modules/ui.py | 34 ++++++++-------------------------- 1 file changed, 8 insertions(+), 26 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index b5581a06..ef7becc6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1918,38 +1918,20 @@ def create_ui(): def reload_javascript(): - scripts_list = modules.scripts.list_scripts("javascript", ".js") - js_files = [] - for basedir, filename, path in scripts_list: - path = path[len(script_path) + 1:] - js_files.append(path) + head = f'\n' - inline = [f"{localization.localization_js(shared.opts.localization)};"] + inline = f"{localization.localization_js(shared.opts.localization)};" if cmd_opts.theme is not None: - inline.append(f"set_theme('{cmd_opts.theme}');", ) + inline += f"set_theme('{cmd_opts.theme}');" - t = int(time.time()) - head = [ - f""" - - """.strip() - ] - inline_code = "\n".join(inline) - head.append(f""" - - """.strip()) - for file in js_files: - head.append(f""" - - """.strip()) + head += f'\n' + + for script in modules.scripts.list_scripts("javascript", ".js"): + head += f'\n' def template_response(*args, **kwargs): res = shared.GradioTemplateResponseOriginal(*args, **kwargs) - head_inject = "\n".join(head) - res.body = res.body.replace( - b'', f'{head_inject}'.encode("utf8")) + res.body = res.body.replace(b'', f'{head}'.encode("utf8")) res.init_headers() return res -- cgit v1.2.3 From 4a8fe09652b304034708d967c47901312940e852 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 23:06:18 +0300 Subject: remove the double loading text --- modules/ui.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modules/ui.py b/modules/ui.py index ef7becc6..aa39a713 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -75,6 +75,7 @@ css_hide_progressbar = """ .wrap .m-12::before { content:"Loading..." } .wrap .z-20 svg { display:none!important; } .wrap .z-20::before { content:"Loading..." 
} +.wrap.cover-bg .z-20::before { content:"" } .progress-bar { display:none!important; } .meta-text { display:none!important; } .meta-text-center { display:none!important; } -- cgit v1.2.3 From 500d9a32c7b1f877c8f44159a9a10c594b545a80 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 23:11:37 +0300 Subject: add --lora-dir commandline option --- extensions-builtin/Lora/lora.py | 9 ++++----- extensions-builtin/Lora/preload.py | 6 ++++++ extensions-builtin/Lora/ui_extra_networks_lora.py | 2 +- 3 files changed, 11 insertions(+), 6 deletions(-) create mode 100644 extensions-builtin/Lora/preload.py diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 6d860224..da1797dc 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -177,12 +177,12 @@ def lora_Conv2d_forward(self, input): def list_available_loras(): available_loras.clear() - os.makedirs(lora_dir, exist_ok=True) + os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) candidates = \ - glob.glob(os.path.join(lora_dir, '**/*.pt'), recursive=True) + \ - glob.glob(os.path.join(lora_dir, '**/*.safetensors'), recursive=True) + \ - glob.glob(os.path.join(lora_dir, '**/*.ckpt'), recursive=True) + glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \ + glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \ + glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True) for filename in sorted(candidates): if os.path.isdir(filename): @@ -193,7 +193,6 @@ def list_available_loras(): available_loras[name] = LoraOnDisk(name, filename) -lora_dir = os.path.join(shared.models_path, "Lora") available_loras = {} loaded_loras = [] diff --git a/extensions-builtin/Lora/preload.py b/extensions-builtin/Lora/preload.py new file mode 100644 index 00000000..863dc5c0 --- /dev/null +++ b/extensions-builtin/Lora/preload.py @@ -0,0 +1,6 @@ +import os +from modules import paths + + +def preload(parser): + parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora')) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 65397890..4406f8a0 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -31,5 +31,5 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): } def allowed_directories_for_previews(self): - return [lora.lora_dir] + return [shared.cmd_opts.lora_dir] -- cgit v1.2.3 From 78f59a4e014d090bce7df3b218bfbcd7f11e0894 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 23:40:13 +0300 Subject: enable compact view for train tab prevent previews from ruining hypernetwork training --- modules/hypernetworks/hypernetwork.py | 2 ++ modules/processing.py | 8 ++++++-- modules/ui.py | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 80a47c79..503534e2 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -715,6 +715,8 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi do_not_save_samples=True, ) + p.disable_extra_networks = True + if preview_from_txt2img: p.prompt = preview_prompt p.negative_prompt = preview_negative_prompt diff --git a/modules/processing.py 
b/modules/processing.py index 6e6371a1..bc541e2f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -140,6 +140,7 @@ class StableDiffusionProcessing: self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} self.override_settings_restore_afterwards = override_settings_restore_afterwards self.is_using_inpainting_conditioning = False + self.disable_extra_networks = False if not seed_enable_extras: self.subseed = -1 @@ -567,7 +568,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) - extra_networks.activate(p, extra_network_data) + if not p.disable_extra_networks: + extra_networks.activate(p, extra_network_data) with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: processed = Processed(p, [], p.seed, "") @@ -684,7 +686,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) - extra_networks.deactivate(p, extra_network_data) + if not p.disable_extra_networks: + extra_networks.deactivate(p, extra_network_data) + devices.torch_gc() res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts) diff --git a/modules/ui.py b/modules/ui.py index daebbc9f..af6dfb21 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1259,7 +1259,7 @@ def create_ui(): with gr.Row().style(equal_height=False): gr.HTML(value="

<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>
      ") - with gr.Row().style(equal_height=False): + with gr.Row(variant="compact").style(equal_height=False): with gr.Tabs(elem_id="train_tabs"): with gr.Tab(label="Create embedding"): -- cgit v1.2.3 From fe7a623e6b7e04bab2cfc96e8fd6cf49b48daee1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 00:02:41 +0300 Subject: add a slider for default value of added extra networks --- extensions-builtin/Lora/ui_extra_networks_lora.py | 3 ++- javascript/hints.js | 3 ++- modules/shared.py | 5 +++-- modules/ui_extra_networks.py | 2 +- modules/ui_extra_networks_hypernets.py | 3 ++- 5 files changed, 10 insertions(+), 6 deletions(-) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 4406f8a0..54a80d36 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -1,3 +1,4 @@ +import json import os import lora @@ -26,7 +27,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): "name": name, "filename": path, "preview": preview, - "prompt": f"", + "prompt": json.dumps(f""), "local_preview": path + ".png", } diff --git a/javascript/hints.js b/javascript/hints.js index ef410fba..2aec71a9 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -107,7 +107,8 @@ titles = { "Hires steps": "Number of sampling steps for upscaled picture. If 0, uses same as for original.", "Upscale by": "Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.", "Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.", - "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders." + "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.", + "Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it." 
} diff --git a/modules/shared.py b/modules/shared.py index 52bbb807..00a1d64c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -398,7 +398,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }), + "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}), @@ -406,7 +406,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), - 'CLIP_stop_at_last_layers': OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), + "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), + "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), })) options_templates.update(options_section(('compatibility', "Compatibility"), { diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index e2e060c8..4c88193f 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -54,7 +54,7 @@ class ExtraNetworksPage: args = { "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '', - "prompt": json.dumps(item["prompt"]), + "prompt": item["prompt"], "tabname": json.dumps(tabname), "local_preview": json.dumps(item["local_preview"]), "name": item["name"], diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 312dbaf0..65d000cf 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -1,3 +1,4 @@ +import json import os from modules import shared, ui_extra_networks @@ -25,7 +26,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): "name": name, "filename": path, "preview": preview, - "prompt": f"<hypernet:{name}:1.0>", + "prompt": json.dumps(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"), "local_preview": path + ".png", } -- cgit v1.2.3 From e5520232e853656e10e4a06f38db24f199474aba Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com>
Date: Sat, 21 Jan 2023 23:58:59 +0300 Subject: make current_axis_options class variable --- scripts/xy_grid.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 8ff315a7..98254c64 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -293,17 +293,17 @@ class Script(scripts.Script): return "X/Y plot" def ui(self, is_img2img): - current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img] + self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img] with gr.Row(): with gr.Column(scale=19): with gr.Row(): - x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) + x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_x_tool_button", visible=False) with gr.Row(): - y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) + y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False) @@ -314,21 +314,20 @@ class Script(scripts.Script): swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button") def swap_axes(x_type, x_values, y_type, y_values): - nonlocal current_axis_options - return current_axis_options[y_type].label, y_values, current_axis_options[x_type].label, x_values + return self.current_axis_options[y_type].label, y_values, self.current_axis_options[x_type].label, x_values swap_args = [x_type, x_values, y_type, y_values] swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args) def fill(x_type): - axis = current_axis_options[x_type] + axis = self.current_axis_options[x_type] return ", ".join(axis.choices()) if axis.choices else gr.update() fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values]) fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values]) def select_axis(x_type): - return gr.Button.update(visible=current_axis_options[x_type].choices is not None) + return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None) x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button]) y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) @@ -403,10 +402,10 @@ class Script(scripts.Script): return valslist - x_opt = axis_options[x_type] + x_opt = self.current_axis_options[x_type] xs = process_axis(x_opt, x_values) - y_opt = axis_options[y_type] + y_opt = self.current_axis_options[y_type] ys = process_axis(y_opt, y_values) def fix_axis_seeds(axis_opt, axis_list): -- cgit v1.2.3 From f2eae6127d16a80d1516d3f6245b648eeca26330 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 00:16:26 +0300 Subject: fix broken textual inversion extras tab --- 
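Context for the xy_grid change above: the dropdown is created with type="index", so its value indexes the filtered per-tab option list; looking that index up in the unfiltered module-level list could select the wrong axis. A toy illustration of the mismatch (data invented for the example):

    axis_options = ["Nothing", "Seed", "Img2img only", "Steps"]
    current_axis_options = [x for x in axis_options if x != "Img2img only"]

    chosen = current_axis_options.index("Steps")  # dropdown reports index 2
    print(axis_options[chosen])                   # 'Img2img only' (wrong list)
    print(current_axis_options[chosen])           # 'Steps' (correct)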
modules/ui_extra_networks_textual_inversion.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index e4a6e3bf..dbd23d2d 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -1,3 +1,4 @@ +import json import os from modules import ui_extra_networks, sd_hijack @@ -24,7 +25,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): "name": embedding.name, "filename": embedding.filename, "preview": preview, - "prompt": embedding.name, + "prompt": json.dumps(embedding.name), "local_preview": path + ".preview.png", } -- cgit v1.2.3 From bf457b30fbfedb4b6eb2a198cbaa9f2ba071d31f Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sat, 21 Jan 2023 16:21:33 -0500 Subject: compact checkbox and fix copy image btn overflow also fixes type for #tab_extensions in style.css --- modules/ui.py | 2 +- style.css | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index af6dfb21..12fc9e6a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -919,7 +919,7 @@ def create_ui(): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('img2img') elif category == "checkboxes": - with FormRow(elem_id="img2img_checkboxes"): + with FormRow(elem_id="img2img_checkboxes", variant="compact"): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling") diff --git a/style.css b/style.css index 507acec1..b215405d 100644 --- a/style.css +++ b/style.css @@ -589,7 +589,7 @@ canvas[key="mask"] { /* Extensions */ -#tab_extensions table``{ +#tab_extensions table{ border-collapse: collapse; } @@ -718,6 +718,10 @@ footer { margin-left: -0.8em; } +#img2img_copy_to_img2img, #img2img_copy_to_sketch, #img2img_copy_to_inpaint, #img2img_copy_to_inpaint_sketch{ + margin-left: 0em; +} + .inactive{ opacity: 0.5; } -- cgit v1.2.3 From 5560150fdaf5d974a122f0b226d6abe24dea12c0 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sat, 21 Jan 2023 16:58:45 -0500 Subject: aligns the axis buttons in x/y plot --- scripts/xy_grid.py | 2 +- style.css | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 8ff315a7..0caece09 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -307,7 +307,7 @@ class Script(scripts.Script): y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False) - with gr.Row(variant="compact"): + with gr.Row(variant="compact", elem_id="axis_options"): draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) diff --git a/style.css b/style.css index b215405d..bf8260d7 100644 --- a/style.css +++ b/style.css @@ -722,6 +722,10 @@ footer { margin-left: 0em; } +#axis_options { + margin-left: 0em; +} + .inactive{ opacity: 0.5; } -- cgit v1.2.3 From 2621566153920eb70bfa439f3d7c126ee8d36ec8 Mon Sep 17 00:00:00 2001 From: 
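A note on the json.dumps changes recurring in this series: the extra-networks card template pastes each prompt value straight into an onclick JavaScript call, so the value must already be a JavaScript string literal; encoding once where the item is built keeps every page consistent. A small sketch of why the encoding matters (the JS function name is invented):

    import json

    prompt = 'a "quoted" prompt'
    onclick = f"cardClicked({json.dumps(prompt)})"
    print(onclick)  # cardClicked("a \"quoted\" prompt")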
AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 08:05:21 +0300 Subject: attention ctrl+up/down enhancements --- javascript/edit-attention.js | 110 ++++++++++++++++++++++++++----------------- modules/shared.py | 8 ++-- 2 files changed, 71 insertions(+), 47 deletions(-) diff --git a/javascript/edit-attention.js b/javascript/edit-attention.js index cec6a530..619bb1fa 100644 --- a/javascript/edit-attention.js +++ b/javascript/edit-attention.js @@ -1,74 +1,96 @@ -addEventListener('keydown', (event) => { +function keyupEditAttention(event){ let target = event.originalTarget || event.composedPath()[0]; if (!target.matches("[id*='_toprow'] textarea.gr-text-input[placeholder]")) return; if (! (event.metaKey || event.ctrlKey)) return; - - let plus = "ArrowUp" - let minus = "ArrowDown" - if (event.key != plus && event.key != minus) return; + let isPlus = event.key == "ArrowUp" + let isMinus = event.key == "ArrowDown" + if (!isPlus && !isMinus) return; let selectionStart = target.selectionStart; let selectionEnd = target.selectionEnd; - // If the user hasn't selected anything, let's select their current parenthesis block - if (selectionStart === selectionEnd) { + let text = target.value; + + function selectCurrentParenthesisBlock(OPEN, CLOSE){ + if (selectionStart !== selectionEnd) return false; + // Find opening parenthesis around current cursor - const before = target.value.substring(0, selectionStart); - let beforeParen = before.lastIndexOf("("); - if (beforeParen == -1) return; - let beforeParenClose = before.lastIndexOf(")"); + const before = text.substring(0, selectionStart); + let beforeParen = before.lastIndexOf(OPEN); + if (beforeParen == -1) return false; + let beforeParenClose = before.lastIndexOf(CLOSE); while (beforeParenClose !== -1 && beforeParenClose > beforeParen) { - beforeParen = before.lastIndexOf("(", beforeParen - 1); - beforeParenClose = before.lastIndexOf(")", beforeParenClose - 1); + beforeParen = before.lastIndexOf(OPEN, beforeParen - 1); + beforeParenClose = before.lastIndexOf(CLOSE, beforeParenClose - 1); } // Find closing parenthesis around current cursor - const after = target.value.substring(selectionStart); - let afterParen = after.indexOf(")"); - if (afterParen == -1) return; - let afterParenOpen = after.indexOf("("); + const after = text.substring(selectionStart); + let afterParen = after.indexOf(CLOSE); + if (afterParen == -1) return false; + let afterParenOpen = after.indexOf(OPEN); while (afterParenOpen !== -1 && afterParen > afterParenOpen) { - afterParen = after.indexOf(")", afterParen + 1); - afterParenOpen = after.indexOf("(", afterParenOpen + 1); + afterParen = after.indexOf(CLOSE, afterParen + 1); + afterParenOpen = after.indexOf(OPEN, afterParenOpen + 1); } - if (beforeParen === -1 || afterParen === -1) return; + if (beforeParen === -1 || afterParen === -1) return false; // Set the selection to the text between the parenthesis - const parenContent = target.value.substring(beforeParen + 1, selectionStart + afterParen); + const parenContent = text.substring(beforeParen + 1, selectionStart + afterParen); const lastColon = parenContent.lastIndexOf(":"); selectionStart = beforeParen + 1; selectionEnd = selectionStart + lastColon; target.setSelectionRange(selectionStart, selectionEnd); - } + return true; + } + + // If the user hasn't selected anything, let's select their current parenthesis block + if(! 
selectCurrentParenthesisBlock('<', '>')){ + selectCurrentParenthesisBlock('(', ')') + } event.preventDefault(); - if (selectionStart == 0 || target.value[selectionStart - 1] != "(") { - target.value = target.value.slice(0, selectionStart) + - "(" + target.value.slice(selectionStart, selectionEnd) + ":1.0)" + - target.value.slice(selectionEnd); + closeCharacter = ')' + delta = opts.keyedit_precision_attention + + if (selectionStart > 0 && text[selectionStart - 1] == '<'){ + closeCharacter = '>' + delta = opts.keyedit_precision_extra + } else if (selectionStart == 0 || text[selectionStart - 1] != "(") { + + // do not include spaces at the end + while(selectionEnd > selectionStart && text[selectionEnd-1] == ' '){ + selectionEnd -= 1; + } + if(selectionStart == selectionEnd){ + return + } - target.focus(); - target.selectionStart = selectionStart + 1; - target.selectionEnd = selectionEnd + 1; + text = text.slice(0, selectionStart) + "(" + text.slice(selectionStart, selectionEnd) + ":1.0)" + text.slice(selectionEnd); - } else { - end = target.value.slice(selectionEnd + 1).indexOf(")") + 1; - weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end)); - if (isNaN(weight)) return; - if (event.key == minus) weight -= 0.1; - if (event.key == plus) weight += 0.1; + selectionStart += 1; + selectionEnd += 1; + } - weight = parseFloat(weight.toPrecision(12)); + end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1; + weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end)); + if (isNaN(weight)) return; - target.value = target.value.slice(0, selectionEnd + 1) + - weight + - target.value.slice(selectionEnd + 1 + end - 1); + weight += isPlus ? delta : -delta; + weight = parseFloat(weight.toPrecision(12)); + if(String(weight).length == 1) weight += ".0" - target.focus(); - target.selectionStart = selectionStart; - target.selectionEnd = selectionEnd; - } + text = text.slice(0, selectionEnd + 1) + weight + text.slice(selectionEnd + 1 + end - 1); + + target.focus(); + target.value = text; + target.selectionStart = selectionStart; + target.selectionEnd = selectionEnd; updateInput(target) -}); +} + +addEventListener('keydown', (event) => { + keyupEditAttention(event); +}); \ No newline at end of file diff --git a/modules/shared.py b/modules/shared.py index 00a1d64c..d68ac296 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -444,9 +444,11 @@ options_templates.update(options_section(('ui', "User interface"), { "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"), "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"), - 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), - 'ui_reorder': OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), - 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), + "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), + "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), + "quicksettings":
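The reworked handler above generalizes the old fixed +/-0.1 step: the delta now comes from the two new precision options, and angle-bracket blocks such as extra networks get their own setting. A rough Python transliteration of the weight bump on a single token (simplified; the real code edits the text around the cursor's selection):

    def bump_weight(token, delta):
        # token like "(cat:1.0)" or "<lora:foo:0.8>"
        open_ch, close_ch = token[0], token[-1]
        name, _, weight = token[1:-1].rpartition(":")
        new_weight = float(f"{float(weight) + delta:.12g}")  # mimics toPrecision(12)
        return f"{open_ch}{name}:{new_weight}{close_ch}"

    print(bump_weight("(cat:1.0)", 0.1))         # (cat:1.1)
    print(bump_weight("<lora:foo:0.8>", -0.05))  # <lora:foo:0.75>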
OptionInfo("sd_model_checkpoint", "Quicksettings list"), + "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), + "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) options_templates.update(options_section(('ui', "Live previews"), { -- cgit v1.2.3 From 0792fae078ba362a5119f56d84e3f490a88690ae Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 08:20:48 +0300 Subject: fix missing field for aesthetic embedding extension --- modules/sd_disable_initialization.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index c72d8efc..e90aa9fe 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -41,7 +41,9 @@ class DisableInitialization: return self.create_model_and_transforms(*args, pretrained=None, **kwargs) def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs): - return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs) + res = self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs) + res.name_or_path = pretrained_model_name_or_path + return res def transformers_modeling_utils_load_pretrained_model(*args, **kwargs): args = args[0:3] + ('/', ) + args[4:] # resolved_archive_file; must set it to something to prevent what seems to be a bug -- cgit v1.2.3 From 112416d04171e4bee673f0adc9bd3aeba87ec71a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 10:17:12 +0300 Subject: add option to discard weights in checkpoint merger UI --- modules/extras.py | 9 ++++++++- modules/ui.py | 4 ++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/modules/extras.py b/modules/extras.py index 1218f88f..385430dc 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -1,6 +1,7 @@ from __future__ import annotations import math import os +import re import sys import traceback import shutil @@ -285,7 +286,7 @@ def to_half(tensor, enable): return tensor -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae): +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights): shared.state.begin() shared.state.job = 'model-merge' @@ -430,6 +431,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ for key in theta_0.keys(): theta_0[key] = to_half(theta_0[key], save_as_half) + if discard_weights: + regex = re.compile(discard_weights) + for key in list(theta_0): + if re.search(regex, key): + theta_0.pop(key, None) + ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path filename = filename_generator() if custom_name == '' else custom_name diff --git a/modules/ui.py b/modules/ui.py index af6dfb21..eb4b7e6b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1248,6 +1248,9 @@ def create_ui(): bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", 
elem_id="modelmerger_bake_in_vae") create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae") + with FormRow(): + discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights") + with gr.Row(): modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') @@ -1838,6 +1841,7 @@ def create_ui(): checkpoint_format, config_source, bake_in_vae, + discard_weights, ], outputs=[ primary_model_name, -- cgit v1.2.3 From 837ec11828a766f6d8109402ed8c856bc16c610a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 10:17:26 +0300 Subject: hint for discarding layers --- javascript/hints.js | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/javascript/hints.js b/javascript/hints.js index 2aec71a9..833543f0 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -108,7 +108,8 @@ titles = { "Upscale by": "Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.", "Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.", "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders.", - "Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it." + "Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.", + "Discard weights with matching name": "Regular expression; if weights's name matches it, the weights is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights." } -- cgit v1.2.3 From 159f05314d0cf55e3891e4e4510ebfc861fa6da5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 10:30:55 +0300 Subject: make extra networks search case-insensitive --- javascript/extraNetworks.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 54ded58c..c5a9adb3 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -13,10 +13,10 @@ function setupExtraNetworksForTab(tabname){ tabs.appendChild(close) search.addEventListener("input", function(evt){ - searchTerm = search.value + searchTerm = search.value.toLowerCase() gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){ - text = elem.querySelector('.name').textContent + text = elem.querySelector('.name').textContent.toLowerCase() elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : "" }) }); -- cgit v1.2.3 From 35419b274614984e2b511a6ad34f37e41481c809 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 11:00:05 +0300 Subject: add an option to reorder tabs for extra networks --- javascript/hints.js | 3 ++- modules/shared.py | 1 + modules/ui_extra_networks.py | 18 +++++++++++++++++- 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/javascript/hints.js b/javascript/hints.js index 833543f0..3cf10e20 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -109,7 +109,8 @@ titles = { "Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.", "Resize height to": "Resizes image to this height. 
If 0, height is inferred from either of two nearby sliders.", "Multiplier for extra networks": "When adding extra network such as Hypernetwork or Lora to prompt, use this multiplier for it.", - "Discard weights with matching name": "Regular expression; if a weight's name matches it, that weight is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights." + "Discard weights with matching name": "Regular expression; if a weight's name matches it, that weight is not written to the resulting checkpoint. Use ^model_ema to discard EMA weights.", + "Extra networks tab order": "Comma-separated list of tab names; tabs listed here will appear in the extra networks UI first, in the order listed." } diff --git a/modules/shared.py b/modules/shared.py index d68ac296..cd78e50a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -448,6 +448,7 @@ options_templates.update(options_section(('ui', "User interface"), { "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), "quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"), "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), + "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"), "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 4c88193f..285c8ffe 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -79,6 +79,22 @@ class ExtraNetworksUi: self.tabname = None +def pages_in_preferred_order(pages): + tab_order = [x.lower().strip() for x in shared.opts.ui_extra_networks_tab_reorder.split(",")] + + def tab_name_score(name): + name = name.lower() + for i, possible_match in enumerate(tab_order): + if possible_match in name: + return i + + return len(pages) + + tab_scores = {page.name: (tab_name_score(page.name), original_index) for original_index, page in enumerate(pages)} + + return sorted(pages, key=lambda x: tab_scores[x.name]) + + def create_ui(container, button, tabname): ui = ExtraNetworksUi() ui.pages = [] @@ -86,7 +102,7 @@ def create_ui(container, button, tabname): ui.tabname = tabname with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs: - for page in ui.stored_extra_pages: + for page in pages_in_preferred_order(ui.stored_extra_pages): with gr.Tab(page.title): page_elem = gr.HTML(page.create_html(ui.tabname)) ui.pages.append(page_elem) -- cgit v1.2.3 From c98cb0f8ecc904666f47684e238dd022039ca16f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 11:04:02 +0300 Subject: amend previous commit to work in a proper fashion when saving previews --- modules/ui_extra_networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 285c8ffe..af2b8071 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -98,11 +98,11 @@ def pages_in_preferred_order(pages): def create_ui(container, button, tabname): ui = ExtraNetworksUi() ui.pages = [] - ui.stored_extra_pages = extra_pages.copy() + ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy()) ui.tabname = tabname with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs: - for page in
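A condensed, standalone version of pages_in_preferred_order above, to make the scoring concrete (toy page names; the real function scores page objects and breaks ties by original index, which the stable sorted() also preserves here):

    def order_tabs(names, reorder_setting):
        tab_order = [x.lower().strip() for x in reorder_setting.split(",")]

        def score(name):
            name = name.lower()
            for i, candidate in enumerate(tab_order):
                if candidate in name:
                    return i
            return len(names)  # unmatched tabs sort after all listed ones

        return sorted(names, key=score)

    print(order_tabs(["Textual Inversion", "Hypernetworks", "Lora"], "lora, hyper"))
    # ['Lora', 'Hypernetworks', 'Textual Inversion']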
pages_in_preferred_order(ui.stored_extra_pages): + for page in ui.stored_extra_pages: with gr.Tab(page.title): page_elem = gr.HTML(page.create_html(ui.tabname)) ui.pages.append(page_elem) -- cgit v1.2.3 From 43ac9ff205910e8207dfd45a842577344d399a92 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Sun, 22 Jan 2023 15:26:40 +0300 Subject: Split history extras.py to postprocessing.py --- modules/extras.py | 466 ---------------------------------------------- modules/postprocessing.py | 466 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 466 insertions(+), 466 deletions(-) delete mode 100644 modules/extras.py create mode 100644 modules/postprocessing.py diff --git a/modules/extras.py b/modules/extras.py deleted file mode 100644 index 385430dc..00000000 --- a/modules/extras.py +++ /dev/null @@ -1,466 +0,0 @@ -from __future__ import annotations -import math -import os -import re -import sys -import traceback -import shutil - -import numpy as np -from PIL import Image - -import torch -import tqdm - -from typing import Callable, List, OrderedDict, Tuple -from functools import partial -from dataclasses import dataclass - -from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae -from modules.shared import opts -import modules.gfpgan_model -from modules.ui import plaintext_to_html -import modules.codeformer_model -import gradio as gr -import safetensors.torch - -class LruCache(OrderedDict): - @dataclass(frozen=True) - class Key: - image_hash: int - info_hash: int - args_hash: int - - @dataclass - class Value: - image: Image.Image - info: str - - def __init__(self, max_size: int = 5, *args, **kwargs): - super().__init__(*args, **kwargs) - self._max_size = max_size - - def get(self, key: LruCache.Key) -> LruCache.Value: - ret = super().get(key) - if ret is not None: - self.move_to_end(key) # Move to end of eviction list - return ret - - def put(self, key: LruCache.Key, value: LruCache.Value) -> None: - self[key] = value - while len(self) > self._max_size: - self.popitem(last=False) - - -cached_images: LruCache = LruCache(max_size=5) - - -def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): - devices.torch_gc() - - shared.state.begin() - shared.state.job = 'extras' - - imageArr = [] - # Also keep track of original file names - imageNameArr = [] - outputs = [] - - if extras_mode == 1: - #convert file to pillow image - for img in image_folder: - image = Image.open(img) - imageArr.append(image) - imageNameArr.append(os.path.splitext(img.orig_name)[0]) - elif extras_mode == 2: - assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' - - if input_dir == '': - return outputs, "Please select an input directory.", '' - image_list = shared.listfiles(input_dir) - for img in image_list: - try: - image = Image.open(img) - except Exception: - continue - imageArr.append(image) - imageNameArr.append(img) - else: - imageArr.append(image) - imageNameArr.append(None) - - if extras_mode == 2 and output_dir != '': - outpath = output_dir - else: - outpath = opts.outdir_samples or opts.outdir_extras_samples - - # Extra operation definitions - - def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - 
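The LruCache at the top of the moved module is a thin OrderedDict wrapper: get() refreshes an entry's position, put() evicts from the front once the cap is exceeded. A minimal standalone demo of the same eviction behaviour, capped at two entries:

    from collections import OrderedDict

    class Lru(OrderedDict):
        def __init__(self, max_size=2):
            super().__init__()
            self._max_size = max_size

        def get(self, key):
            ret = super().get(key)
            if ret is not None:
                self.move_to_end(key)  # freshly used, so evicted last
            return ret

        def put(self, key, value):
            self[key] = value
            while len(self) > self._max_size:
                self.popitem(last=False)  # drop the least recently used

    cache = Lru()
    cache.put("a", 1)
    cache.put("b", 2)
    cache.get("a")      # touch "a"; "b" becomes the eviction candidate
    cache.put("c", 3)   # evicts "b"
    print(list(cache))  # ['a', 'c']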
shared.state.job = 'extras-gfpgan' - restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) - res = Image.fromarray(restored_img) - - if gfpgan_visibility < 1.0: - res = Image.blend(image, res, gfpgan_visibility) - - info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" - return (res, info) - - def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-codeformer' - restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) - res = Image.fromarray(restored_img) - - if codeformer_visibility < 1.0: - res = Image.blend(image, res, codeformer_visibility) - - info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" - return (res, info) - - def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): - shared.state.job = 'extras-upscale' - upscaler = shared.sd_upscalers[scaler_index] - res = upscaler.scaler.upscale(image, resize, upscaler.data_path) - if mode == 1 and crop: - cropped = Image.new("RGB", (resize_w, resize_h)) - cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) - res = cropped - return res - - def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text - nonlocal upscaling_resize - if resize_mode == 1: - upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) - crop_info = " (crop)" if upscaling_crop else "" - info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" - return (image, info) - - @dataclass - class UpscaleParams: - upscaler_idx: int - blend_alpha: float - - def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: - blended_result: Image.Image = None - image_hash: str = hash(np.array(image.getdata()).tobytes()) - for upscaler in params: - upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, - upscaling_resize_w, upscaling_resize_h, upscaling_crop) - cache_key = LruCache.Key(image_hash=image_hash, - info_hash=hash(info), - args_hash=hash(upscale_args)) - cached_entry = cached_images.get(cache_key) - if cached_entry is None: - res = upscale(image, *upscale_args) - info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" - cached_images.put(cache_key, LruCache.Value(image=res, info=info)) - else: - res, info = cached_entry.image, cached_entry.info - - if blended_result is None: - blended_result = res - else: - blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) - return (blended_result, info) - - # Build a list of operations to run - facefix_ops: List[Callable] = [] - facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else [] - facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] - - upscale_ops: List[Callable] = [] - upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] - - if upscaling_resize != 0: - step_params: List[UpscaleParams] = [] - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) - if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility)) - - upscale_ops.append(partial(run_upscalers_blend, 
step_params)) - - extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) - - for image, image_name in zip(imageArr, imageNameArr): - if image is None: - return outputs, "Please select an input image.", '' - - shared.state.textinfo = f'Processing image {image_name}' - - existing_pnginfo = image.info or {} - - image = image.convert("RGB") - info = "" - # Run each operation on each image - for op in extras_ops: - image, info = op(image, info) - - if opts.use_original_name_batch and image_name is not None: - basename = os.path.splitext(os.path.basename(image_name))[0] - else: - basename = '' - - if opts.enable_pnginfo: # append info before save - image.info = existing_pnginfo - image.info["extras"] = info - - if save_output: - # Add upscaler name as a suffix. - suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" - # Add second upscaler if applicable. - if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: - suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" - - images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) - - if extras_mode != 2 or show_extras_results : - outputs.append(image) - - devices.torch_gc() - - return outputs, plaintext_to_html(info), '' - -def clear_cache(): - cached_images.clear() - - -def run_pnginfo(image): - if image is None: - return '', '', '' - - geninfo, items = images.read_info_from_image(image) - items = {**{'parameters': geninfo}, **items} - - info = '' - for key, text in items.items(): - info += f""" -
      -

      {plaintext_to_html(str(key))}

      -

      {plaintext_to_html(str(text))}

      -
      -""".strip()+"\n" - - if len(info) == 0: - message = "Nothing found in the image." - info = f"

      {message}

      " - - return '', geninfo, info - - -def create_config(ckpt_result, config_source, a, b, c): - def config(x): - res = sd_models.find_checkpoint_config(x) if x else None - return res if res != shared.sd_default_config else None - - if config_source == 0: - cfg = config(a) or config(b) or config(c) - elif config_source == 1: - cfg = config(b) - elif config_source == 2: - cfg = config(c) - else: - cfg = None - - if cfg is None: - return - - filename, _ = os.path.splitext(ckpt_result) - checkpoint_filename = filename + ".yaml" - - print("Copying config:") - print(" from:", cfg) - print(" to:", checkpoint_filename) - shutil.copyfile(cfg, checkpoint_filename) - - -checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] - - -def to_half(tensor, enable): - if enable and tensor.dtype == torch.float: - return tensor.half() - - return tensor - - -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights): - shared.state.begin() - shared.state.job = 'model-merge' - - def fail(message): - shared.state.textinfo = message - shared.state.end() - return [*[gr.update() for _ in range(4)], message] - - def weighted_sum(theta0, theta1, alpha): - return ((1 - alpha) * theta0) + (alpha * theta1) - - def get_difference(theta1, theta2): - return theta1 - theta2 - - def add_difference(theta0, theta1_2_diff, alpha): - return theta0 + (alpha * theta1_2_diff) - - def filename_weighted_sum(): - a = primary_model_info.model_name - b = secondary_model_info.model_name - Ma = round(1 - multiplier, 2) - Mb = round(multiplier, 2) - - return f"{Ma}({a}) + {Mb}({b})" - - def filename_add_difference(): - a = primary_model_info.model_name - b = secondary_model_info.model_name - c = tertiary_model_info.model_name - M = round(multiplier, 2) - - return f"{a} + {M}({b} - {c})" - - def filename_nothing(): - return primary_model_info.model_name - - theta_funcs = { - "Weighted sum": (filename_weighted_sum, None, weighted_sum), - "Add difference": (filename_add_difference, get_difference, add_difference), - "No interpolation": (filename_nothing, None, None), - } - filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] - shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0) - - if not primary_model_name: - return fail("Failed: Merging requires a primary model.") - - primary_model_info = sd_models.checkpoints_list[primary_model_name] - - if theta_func2 and not secondary_model_name: - return fail("Failed: Merging requires a secondary model.") - - secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None - - if theta_func1 and not tertiary_model_name: - return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") - - tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None - - result_is_inpainting_model = False - - if theta_func2: - shared.state.textinfo = f"Loading B" - print(f"Loading {secondary_model_info.filename}...") - theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') - else: - theta_1 = None - - if theta_func1: - shared.state.textinfo = f"Loading C" - print(f"Loading {tertiary_model_info.filename}...") - theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') - - shared.state.textinfo = 'Merging B and C' - 
shared.state.sampling_steps = len(theta_1.keys()) - for key in tqdm.tqdm(theta_1.keys()): - if key in checkpoint_dict_skip_on_merge: - continue - - if 'model' in key: - if key in theta_2: - t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) - theta_1[key] = theta_func1(theta_1[key], t2) - else: - theta_1[key] = torch.zeros_like(theta_1[key]) - - shared.state.sampling_step += 1 - del theta_2 - - shared.state.nextjob() - - shared.state.textinfo = f"Loading {primary_model_info.filename}..." - print(f"Loading {primary_model_info.filename}...") - theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') - - print("Merging...") - shared.state.textinfo = 'Merging A and B' - shared.state.sampling_steps = len(theta_0.keys()) - for key in tqdm.tqdm(theta_0.keys()): - if theta_1 and 'model' in key and key in theta_1: - - if key in checkpoint_dict_skip_on_merge: - continue - - a = theta_0[key] - b = theta_1[key] - - # this enables merging an inpainting model (A) with another one (B); - # where normal model would have 4 channels, for latenst space, inpainting model would - # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 - if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]: - if a.shape[1] == 4 and b.shape[1] == 9: - raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") - - assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" - - theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) - result_is_inpainting_model = True - else: - theta_0[key] = theta_func2(a, b, multiplier) - - theta_0[key] = to_half(theta_0[key], save_as_half) - - shared.state.sampling_step += 1 - - del theta_1 - - bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None) - if bake_in_vae_filename is not None: - print(f"Baking in VAE from {bake_in_vae_filename}") - shared.state.textinfo = 'Baking in VAE' - vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') - - for key in vae_dict.keys(): - theta_0_key = 'first_stage_model.' + key - if theta_0_key in theta_0: - theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half) - - del vae_dict - - if save_as_half and not theta_func2: - for key in theta_0.keys(): - theta_0[key] = to_half(theta_0[key], save_as_half) - - if discard_weights: - regex = re.compile(discard_weights) - for key in list(theta_0): - if re.search(regex, key): - theta_0.pop(key, None) - - ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path - - filename = filename_generator() if custom_name == '' else custom_name - filename += ".inpainting" if result_is_inpainting_model else "" - filename += "." 
+ checkpoint_format - - output_modelname = os.path.join(ckpt_dir, filename) - - shared.state.nextjob() - shared.state.textinfo = "Saving" - print(f"Saving to {output_modelname}...") - - _, extension = os.path.splitext(output_modelname) - if extension.lower() == ".safetensors": - safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"}) - else: - torch.save(theta_0, output_modelname) - - sd_models.list_models() - - create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) - - print(f"Checkpoint saved to {output_modelname}.") - shared.state.textinfo = "Checkpoint saved" - shared.state.end() - - return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] diff --git a/modules/postprocessing.py b/modules/postprocessing.py new file mode 100644 index 00000000..385430dc --- /dev/null +++ b/modules/postprocessing.py @@ -0,0 +1,466 @@ +from __future__ import annotations +import math +import os +import re +import sys +import traceback +import shutil + +import numpy as np +from PIL import Image + +import torch +import tqdm + +from typing import Callable, List, OrderedDict, Tuple +from functools import partial +from dataclasses import dataclass + +from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae +from modules.shared import opts +import modules.gfpgan_model +from modules.ui import plaintext_to_html +import modules.codeformer_model +import gradio as gr +import safetensors.torch + +class LruCache(OrderedDict): + @dataclass(frozen=True) + class Key: + image_hash: int + info_hash: int + args_hash: int + + @dataclass + class Value: + image: Image.Image + info: str + + def __init__(self, max_size: int = 5, *args, **kwargs): + super().__init__(*args, **kwargs) + self._max_size = max_size + + def get(self, key: LruCache.Key) -> LruCache.Value: + ret = super().get(key) + if ret is not None: + self.move_to_end(key) # Move to end of eviction list + return ret + + def put(self, key: LruCache.Key, value: LruCache.Value) -> None: + self[key] = value + while len(self) > self._max_size: + self.popitem(last=False) + + +cached_images: LruCache = LruCache(max_size=5) + + +def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): + devices.torch_gc() + + shared.state.begin() + shared.state.job = 'extras' + + imageArr = [] + # Also keep track of original file names + imageNameArr = [] + outputs = [] + + if extras_mode == 1: + #convert file to pillow image + for img in image_folder: + image = Image.open(img) + imageArr.append(image) + imageNameArr.append(os.path.splitext(img.orig_name)[0]) + elif extras_mode == 2: + assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' + + if input_dir == '': + return outputs, "Please select an input directory.", '' + image_list = shared.listfiles(input_dir) + for img in image_list: + try: + image = Image.open(img) + except Exception: + continue + imageArr.append(image) + imageNameArr.append(img) + else: + imageArr.append(image) + imageNameArr.append(None) + + if extras_mode == 2 and output_dir != '': + outpath = output_dir + else: + outpath = opts.outdir_samples or 
opts.outdir_extras_samples + + # Extra operation definitions + + def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + shared.state.job = 'extras-gfpgan' + restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) + res = Image.fromarray(restored_img) + + if gfpgan_visibility < 1.0: + res = Image.blend(image, res, gfpgan_visibility) + + info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" + return (res, info) + + def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + shared.state.job = 'extras-codeformer' + restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) + res = Image.fromarray(restored_img) + + if codeformer_visibility < 1.0: + res = Image.blend(image, res, codeformer_visibility) + + info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" + return (res, info) + + def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): + shared.state.job = 'extras-upscale' + upscaler = shared.sd_upscalers[scaler_index] + res = upscaler.scaler.upscale(image, resize, upscaler.data_path) + if mode == 1 and crop: + cropped = Image.new("RGB", (resize_w, resize_h)) + cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) + res = cropped + return res + + def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text + nonlocal upscaling_resize + if resize_mode == 1: + upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) + crop_info = " (crop)" if upscaling_crop else "" + info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" + return (image, info) + + @dataclass + class UpscaleParams: + upscaler_idx: int + blend_alpha: float + + def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: + blended_result: Image.Image = None + image_hash: str = hash(np.array(image.getdata()).tobytes()) + for upscaler in params: + upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, + upscaling_resize_w, upscaling_resize_h, upscaling_crop) + cache_key = LruCache.Key(image_hash=image_hash, + info_hash=hash(info), + args_hash=hash(upscale_args)) + cached_entry = cached_images.get(cache_key) + if cached_entry is None: + res = upscale(image, *upscale_args) + info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" + cached_images.put(cache_key, LruCache.Value(image=res, info=info)) + else: + res, info = cached_entry.image, cached_entry.info + + if blended_result is None: + blended_result = res + else: + blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) + return (blended_result, info) + + # Build a list of operations to run + facefix_ops: List[Callable] = [] + facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else [] + facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] + + upscale_ops: List[Callable] = [] + upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] + + if upscaling_resize != 0: + step_params: List[UpscaleParams] = [] + step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) + if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: + 
step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility)) + + upscale_ops.append(partial(run_upscalers_blend, step_params)) + + extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) + + for image, image_name in zip(imageArr, imageNameArr): + if image is None: + return outputs, "Please select an input image.", '' + + shared.state.textinfo = f'Processing image {image_name}' + + existing_pnginfo = image.info or {} + + image = image.convert("RGB") + info = "" + # Run each operation on each image + for op in extras_ops: + image, info = op(image, info) + + if opts.use_original_name_batch and image_name is not None: + basename = os.path.splitext(os.path.basename(image_name))[0] + else: + basename = '' + + if opts.enable_pnginfo: # append info before save + image.info = existing_pnginfo + image.info["extras"] = info + + if save_output: + # Add upscaler name as a suffix. + suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" + # Add second upscaler if applicable. + if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: + suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" + + images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, + no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) + + if extras_mode != 2 or show_extras_results : + outputs.append(image) + + devices.torch_gc() + + return outputs, plaintext_to_html(info), '' + +def clear_cache(): + cached_images.clear() + + +def run_pnginfo(image): + if image is None: + return '', '', '' + + geninfo, items = images.read_info_from_image(image) + items = {**{'parameters': geninfo}, **items} + + info = '' + for key, text in items.items(): + info += f""" +
<div> +<p><b>{plaintext_to_html(str(key))}</b></p> +<p>{plaintext_to_html(str(text))}</p> +</div> +""".strip()+"\n" + + if len(info) == 0: + message = "Nothing found in the image." + info = f"<div><p>{message}</p></div>
      " + + return '', geninfo, info + + +def create_config(ckpt_result, config_source, a, b, c): + def config(x): + res = sd_models.find_checkpoint_config(x) if x else None + return res if res != shared.sd_default_config else None + + if config_source == 0: + cfg = config(a) or config(b) or config(c) + elif config_source == 1: + cfg = config(b) + elif config_source == 2: + cfg = config(c) + else: + cfg = None + + if cfg is None: + return + + filename, _ = os.path.splitext(ckpt_result) + checkpoint_filename = filename + ".yaml" + + print("Copying config:") + print(" from:", cfg) + print(" to:", checkpoint_filename) + shutil.copyfile(cfg, checkpoint_filename) + + +checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] + + +def to_half(tensor, enable): + if enable and tensor.dtype == torch.float: + return tensor.half() + + return tensor + + +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights): + shared.state.begin() + shared.state.job = 'model-merge' + + def fail(message): + shared.state.textinfo = message + shared.state.end() + return [*[gr.update() for _ in range(4)], message] + + def weighted_sum(theta0, theta1, alpha): + return ((1 - alpha) * theta0) + (alpha * theta1) + + def get_difference(theta1, theta2): + return theta1 - theta2 + + def add_difference(theta0, theta1_2_diff, alpha): + return theta0 + (alpha * theta1_2_diff) + + def filename_weighted_sum(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + Ma = round(1 - multiplier, 2) + Mb = round(multiplier, 2) + + return f"{Ma}({a}) + {Mb}({b})" + + def filename_add_difference(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + c = tertiary_model_info.model_name + M = round(multiplier, 2) + + return f"{a} + {M}({b} - {c})" + + def filename_nothing(): + return primary_model_info.model_name + + theta_funcs = { + "Weighted sum": (filename_weighted_sum, None, weighted_sum), + "Add difference": (filename_add_difference, get_difference, add_difference), + "No interpolation": (filename_nothing, None, None), + } + filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] + shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0) + + if not primary_model_name: + return fail("Failed: Merging requires a primary model.") + + primary_model_info = sd_models.checkpoints_list[primary_model_name] + + if theta_func2 and not secondary_model_name: + return fail("Failed: Merging requires a secondary model.") + + secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None + + if theta_func1 and not tertiary_model_name: + return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") + + tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None + + result_is_inpainting_model = False + + if theta_func2: + shared.state.textinfo = f"Loading B" + print(f"Loading {secondary_model_info.filename}...") + theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') + else: + theta_1 = None + + if theta_func1: + shared.state.textinfo = f"Loading C" + print(f"Loading {tertiary_model_info.filename}...") + theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') + + shared.state.textinfo = 'Merging B and C' + 
shared.state.sampling_steps = len(theta_1.keys()) + for key in tqdm.tqdm(theta_1.keys()): + if key in checkpoint_dict_skip_on_merge: + continue + + if 'model' in key: + if key in theta_2: + t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) + theta_1[key] = theta_func1(theta_1[key], t2) + else: + theta_1[key] = torch.zeros_like(theta_1[key]) + + shared.state.sampling_step += 1 + del theta_2 + + shared.state.nextjob() + + shared.state.textinfo = f"Loading {primary_model_info.filename}..." + print(f"Loading {primary_model_info.filename}...") + theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') + + print("Merging...") + shared.state.textinfo = 'Merging A and B' + shared.state.sampling_steps = len(theta_0.keys()) + for key in tqdm.tqdm(theta_0.keys()): + if theta_1 and 'model' in key and key in theta_1: + + if key in checkpoint_dict_skip_on_merge: + continue + + a = theta_0[key] + b = theta_1[key] + + # this enables merging an inpainting model (A) with another one (B); + # where normal model would have 4 channels, for latenst space, inpainting model would + # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 + if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]: + if a.shape[1] == 4 and b.shape[1] == 9: + raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") + + assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" + + theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) + result_is_inpainting_model = True + else: + theta_0[key] = theta_func2(a, b, multiplier) + + theta_0[key] = to_half(theta_0[key], save_as_half) + + shared.state.sampling_step += 1 + + del theta_1 + + bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None) + if bake_in_vae_filename is not None: + print(f"Baking in VAE from {bake_in_vae_filename}") + shared.state.textinfo = 'Baking in VAE' + vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') + + for key in vae_dict.keys(): + theta_0_key = 'first_stage_model.' + key + if theta_0_key in theta_0: + theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half) + + del vae_dict + + if save_as_half and not theta_func2: + for key in theta_0.keys(): + theta_0[key] = to_half(theta_0[key], save_as_half) + + if discard_weights: + regex = re.compile(discard_weights) + for key in list(theta_0): + if re.search(regex, key): + theta_0.pop(key, None) + + ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path + + filename = filename_generator() if custom_name == '' else custom_name + filename += ".inpainting" if result_is_inpainting_model else "" + filename += "." 
+ checkpoint_format + + output_modelname = os.path.join(ckpt_dir, filename) + + shared.state.nextjob() + shared.state.textinfo = "Saving" + print(f"Saving to {output_modelname}...") + + _, extension = os.path.splitext(output_modelname) + if extension.lower() == ".safetensors": + safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"}) + else: + torch.save(theta_0, output_modelname) + + sd_models.list_models() + + create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) + + print(f"Checkpoint saved to {output_modelname}.") + shared.state.textinfo = "Checkpoint saved" + shared.state.end() + + return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] -- cgit v1.2.3 From b238b14ee459486c4734cc2899b83f547813a467 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Sun, 22 Jan 2023 15:26:40 +0300 Subject: Split history extras.py to postprocessing.py --- modules/extras.py | 466 ------------------------------------------------------ modules/temp | 466 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 466 insertions(+), 466 deletions(-) delete mode 100644 modules/extras.py create mode 100644 modules/temp diff --git a/modules/extras.py b/modules/extras.py deleted file mode 100644 index 385430dc..00000000 --- a/modules/extras.py +++ /dev/null @@ -1,466 +0,0 @@ -from __future__ import annotations -import math -import os -import re -import sys -import traceback -import shutil - -import numpy as np -from PIL import Image - -import torch -import tqdm - -from typing import Callable, List, OrderedDict, Tuple -from functools import partial -from dataclasses import dataclass - -from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae -from modules.shared import opts -import modules.gfpgan_model -from modules.ui import plaintext_to_html -import modules.codeformer_model -import gradio as gr -import safetensors.torch - -class LruCache(OrderedDict): - @dataclass(frozen=True) - class Key: - image_hash: int - info_hash: int - args_hash: int - - @dataclass - class Value: - image: Image.Image - info: str - - def __init__(self, max_size: int = 5, *args, **kwargs): - super().__init__(*args, **kwargs) - self._max_size = max_size - - def get(self, key: LruCache.Key) -> LruCache.Value: - ret = super().get(key) - if ret is not None: - self.move_to_end(key) # Move to end of eviction list - return ret - - def put(self, key: LruCache.Key, value: LruCache.Value) -> None: - self[key] = value - while len(self) > self._max_size: - self.popitem(last=False) - - -cached_images: LruCache = LruCache(max_size=5) - - -def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): - devices.torch_gc() - - shared.state.begin() - shared.state.job = 'extras' - - imageArr = [] - # Also keep track of original file names - imageNameArr = [] - outputs = [] - - if extras_mode == 1: - #convert file to pillow image - for img in image_folder: - image = Image.open(img) - imageArr.append(image) - imageNameArr.append(os.path.splitext(img.orig_name)[0]) - elif extras_mode == 2: - assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config 
option must be disabled' - - if input_dir == '': - return outputs, "Please select an input directory.", '' - image_list = shared.listfiles(input_dir) - for img in image_list: - try: - image = Image.open(img) - except Exception: - continue - imageArr.append(image) - imageNameArr.append(img) - else: - imageArr.append(image) - imageNameArr.append(None) - - if extras_mode == 2 and output_dir != '': - outpath = output_dir - else: - outpath = opts.outdir_samples or opts.outdir_extras_samples - - # Extra operation definitions - - def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-gfpgan' - restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) - res = Image.fromarray(restored_img) - - if gfpgan_visibility < 1.0: - res = Image.blend(image, res, gfpgan_visibility) - - info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" - return (res, info) - - def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-codeformer' - restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) - res = Image.fromarray(restored_img) - - if codeformer_visibility < 1.0: - res = Image.blend(image, res, codeformer_visibility) - - info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" - return (res, info) - - def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): - shared.state.job = 'extras-upscale' - upscaler = shared.sd_upscalers[scaler_index] - res = upscaler.scaler.upscale(image, resize, upscaler.data_path) - if mode == 1 and crop: - cropped = Image.new("RGB", (resize_w, resize_h)) - cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) - res = cropped - return res - - def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text - nonlocal upscaling_resize - if resize_mode == 1: - upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) - crop_info = " (crop)" if upscaling_crop else "" - info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" - return (image, info) - - @dataclass - class UpscaleParams: - upscaler_idx: int - blend_alpha: float - - def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: - blended_result: Image.Image = None - image_hash: str = hash(np.array(image.getdata()).tobytes()) - for upscaler in params: - upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, - upscaling_resize_w, upscaling_resize_h, upscaling_crop) - cache_key = LruCache.Key(image_hash=image_hash, - info_hash=hash(info), - args_hash=hash(upscale_args)) - cached_entry = cached_images.get(cache_key) - if cached_entry is None: - res = upscale(image, *upscale_args) - info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" - cached_images.put(cache_key, LruCache.Value(image=res, info=info)) - else: - res, info = cached_entry.image, cached_entry.info - - if blended_result is None: - blended_result = res - else: - blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) - return (blended_result, info) - - # Build a list of operations to run - facefix_ops: List[Callable] = [] - facefix_ops += 
[run_gfpgan] if gfpgan_visibility > 0 else [] - facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] - - upscale_ops: List[Callable] = [] - upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] - - if upscaling_resize != 0: - step_params: List[UpscaleParams] = [] - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) - if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility)) - - upscale_ops.append(partial(run_upscalers_blend, step_params)) - - extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) - - for image, image_name in zip(imageArr, imageNameArr): - if image is None: - return outputs, "Please select an input image.", '' - - shared.state.textinfo = f'Processing image {image_name}' - - existing_pnginfo = image.info or {} - - image = image.convert("RGB") - info = "" - # Run each operation on each image - for op in extras_ops: - image, info = op(image, info) - - if opts.use_original_name_batch and image_name is not None: - basename = os.path.splitext(os.path.basename(image_name))[0] - else: - basename = '' - - if opts.enable_pnginfo: # append info before save - image.info = existing_pnginfo - image.info["extras"] = info - - if save_output: - # Add upscaler name as a suffix. - suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" - # Add second upscaler if applicable. - if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: - suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" - - images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) - - if extras_mode != 2 or show_extras_results : - outputs.append(image) - - devices.torch_gc() - - return outputs, plaintext_to_html(info), '' - -def clear_cache(): - cached_images.clear() - - -def run_pnginfo(image): - if image is None: - return '', '', '' - - geninfo, items = images.read_info_from_image(image) - items = {**{'parameters': geninfo}, **items} - - info = '' - for key, text in items.items(): - info += f""" -
<div>
-<p><b>{plaintext_to_html(str(key))}</b></p>
-<p>{plaintext_to_html(str(text))}</p>
-</div>
      -""".strip()+"\n" - - if len(info) == 0: - message = "Nothing found in the image." - info = f"

<div><p>{message}</p></div>

      " - - return '', geninfo, info - - -def create_config(ckpt_result, config_source, a, b, c): - def config(x): - res = sd_models.find_checkpoint_config(x) if x else None - return res if res != shared.sd_default_config else None - - if config_source == 0: - cfg = config(a) or config(b) or config(c) - elif config_source == 1: - cfg = config(b) - elif config_source == 2: - cfg = config(c) - else: - cfg = None - - if cfg is None: - return - - filename, _ = os.path.splitext(ckpt_result) - checkpoint_filename = filename + ".yaml" - - print("Copying config:") - print(" from:", cfg) - print(" to:", checkpoint_filename) - shutil.copyfile(cfg, checkpoint_filename) - - -checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] - - -def to_half(tensor, enable): - if enable and tensor.dtype == torch.float: - return tensor.half() - - return tensor - - -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights): - shared.state.begin() - shared.state.job = 'model-merge' - - def fail(message): - shared.state.textinfo = message - shared.state.end() - return [*[gr.update() for _ in range(4)], message] - - def weighted_sum(theta0, theta1, alpha): - return ((1 - alpha) * theta0) + (alpha * theta1) - - def get_difference(theta1, theta2): - return theta1 - theta2 - - def add_difference(theta0, theta1_2_diff, alpha): - return theta0 + (alpha * theta1_2_diff) - - def filename_weighted_sum(): - a = primary_model_info.model_name - b = secondary_model_info.model_name - Ma = round(1 - multiplier, 2) - Mb = round(multiplier, 2) - - return f"{Ma}({a}) + {Mb}({b})" - - def filename_add_difference(): - a = primary_model_info.model_name - b = secondary_model_info.model_name - c = tertiary_model_info.model_name - M = round(multiplier, 2) - - return f"{a} + {M}({b} - {c})" - - def filename_nothing(): - return primary_model_info.model_name - - theta_funcs = { - "Weighted sum": (filename_weighted_sum, None, weighted_sum), - "Add difference": (filename_add_difference, get_difference, add_difference), - "No interpolation": (filename_nothing, None, None), - } - filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] - shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0) - - if not primary_model_name: - return fail("Failed: Merging requires a primary model.") - - primary_model_info = sd_models.checkpoints_list[primary_model_name] - - if theta_func2 and not secondary_model_name: - return fail("Failed: Merging requires a secondary model.") - - secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None - - if theta_func1 and not tertiary_model_name: - return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") - - tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None - - result_is_inpainting_model = False - - if theta_func2: - shared.state.textinfo = f"Loading B" - print(f"Loading {secondary_model_info.filename}...") - theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') - else: - theta_1 = None - - if theta_func1: - shared.state.textinfo = f"Loading C" - print(f"Loading {tertiary_model_info.filename}...") - theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') - - shared.state.textinfo = 'Merging B and C' - 
shared.state.sampling_steps = len(theta_1.keys()) - for key in tqdm.tqdm(theta_1.keys()): - if key in checkpoint_dict_skip_on_merge: - continue - - if 'model' in key: - if key in theta_2: - t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) - theta_1[key] = theta_func1(theta_1[key], t2) - else: - theta_1[key] = torch.zeros_like(theta_1[key]) - - shared.state.sampling_step += 1 - del theta_2 - - shared.state.nextjob() - - shared.state.textinfo = f"Loading {primary_model_info.filename}..." - print(f"Loading {primary_model_info.filename}...") - theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') - - print("Merging...") - shared.state.textinfo = 'Merging A and B' - shared.state.sampling_steps = len(theta_0.keys()) - for key in tqdm.tqdm(theta_0.keys()): - if theta_1 and 'model' in key and key in theta_1: - - if key in checkpoint_dict_skip_on_merge: - continue - - a = theta_0[key] - b = theta_1[key] - - # this enables merging an inpainting model (A) with another one (B); - # where normal model would have 4 channels, for latenst space, inpainting model would - # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 - if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]: - if a.shape[1] == 4 and b.shape[1] == 9: - raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") - - assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" - - theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) - result_is_inpainting_model = True - else: - theta_0[key] = theta_func2(a, b, multiplier) - - theta_0[key] = to_half(theta_0[key], save_as_half) - - shared.state.sampling_step += 1 - - del theta_1 - - bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None) - if bake_in_vae_filename is not None: - print(f"Baking in VAE from {bake_in_vae_filename}") - shared.state.textinfo = 'Baking in VAE' - vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') - - for key in vae_dict.keys(): - theta_0_key = 'first_stage_model.' + key - if theta_0_key in theta_0: - theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half) - - del vae_dict - - if save_as_half and not theta_func2: - for key in theta_0.keys(): - theta_0[key] = to_half(theta_0[key], save_as_half) - - if discard_weights: - regex = re.compile(discard_weights) - for key in list(theta_0): - if re.search(regex, key): - theta_0.pop(key, None) - - ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path - - filename = filename_generator() if custom_name == '' else custom_name - filename += ".inpainting" if result_is_inpainting_model else "" - filename += "." 
+ checkpoint_format - - output_modelname = os.path.join(ckpt_dir, filename) - - shared.state.nextjob() - shared.state.textinfo = "Saving" - print(f"Saving to {output_modelname}...") - - _, extension = os.path.splitext(output_modelname) - if extension.lower() == ".safetensors": - safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"}) - else: - torch.save(theta_0, output_modelname) - - sd_models.list_models() - - create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) - - print(f"Checkpoint saved to {output_modelname}.") - shared.state.textinfo = "Checkpoint saved" - shared.state.end() - - return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] diff --git a/modules/temp b/modules/temp new file mode 100644 index 00000000..385430dc --- /dev/null +++ b/modules/temp @@ -0,0 +1,466 @@ +from __future__ import annotations +import math +import os +import re +import sys +import traceback +import shutil + +import numpy as np +from PIL import Image + +import torch +import tqdm + +from typing import Callable, List, OrderedDict, Tuple +from functools import partial +from dataclasses import dataclass + +from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae +from modules.shared import opts +import modules.gfpgan_model +from modules.ui import plaintext_to_html +import modules.codeformer_model +import gradio as gr +import safetensors.torch + +class LruCache(OrderedDict): + @dataclass(frozen=True) + class Key: + image_hash: int + info_hash: int + args_hash: int + + @dataclass + class Value: + image: Image.Image + info: str + + def __init__(self, max_size: int = 5, *args, **kwargs): + super().__init__(*args, **kwargs) + self._max_size = max_size + + def get(self, key: LruCache.Key) -> LruCache.Value: + ret = super().get(key) + if ret is not None: + self.move_to_end(key) # Move to end of eviction list + return ret + + def put(self, key: LruCache.Key, value: LruCache.Value) -> None: + self[key] = value + while len(self) > self._max_size: + self.popitem(last=False) + + +cached_images: LruCache = LruCache(max_size=5) + + +def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): + devices.torch_gc() + + shared.state.begin() + shared.state.job = 'extras' + + imageArr = [] + # Also keep track of original file names + imageNameArr = [] + outputs = [] + + if extras_mode == 1: + #convert file to pillow image + for img in image_folder: + image = Image.open(img) + imageArr.append(image) + imageNameArr.append(os.path.splitext(img.orig_name)[0]) + elif extras_mode == 2: + assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' + + if input_dir == '': + return outputs, "Please select an input directory.", '' + image_list = shared.listfiles(input_dir) + for img in image_list: + try: + image = Image.open(img) + except Exception: + continue + imageArr.append(image) + imageNameArr.append(img) + else: + imageArr.append(image) + imageNameArr.append(None) + + if extras_mode == 2 and output_dir != '': + outpath = output_dir + else: + outpath = opts.outdir_samples or opts.outdir_extras_samples + + # Extra operation 
definitions + + def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + shared.state.job = 'extras-gfpgan' + restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) + res = Image.fromarray(restored_img) + + if gfpgan_visibility < 1.0: + res = Image.blend(image, res, gfpgan_visibility) + + info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" + return (res, info) + + def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + shared.state.job = 'extras-codeformer' + restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) + res = Image.fromarray(restored_img) + + if codeformer_visibility < 1.0: + res = Image.blend(image, res, codeformer_visibility) + + info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" + return (res, info) + + def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): + shared.state.job = 'extras-upscale' + upscaler = shared.sd_upscalers[scaler_index] + res = upscaler.scaler.upscale(image, resize, upscaler.data_path) + if mode == 1 and crop: + cropped = Image.new("RGB", (resize_w, resize_h)) + cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) + res = cropped + return res + + def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text + nonlocal upscaling_resize + if resize_mode == 1: + upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) + crop_info = " (crop)" if upscaling_crop else "" + info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" + return (image, info) + + @dataclass + class UpscaleParams: + upscaler_idx: int + blend_alpha: float + + def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: + blended_result: Image.Image = None + image_hash: str = hash(np.array(image.getdata()).tobytes()) + for upscaler in params: + upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, + upscaling_resize_w, upscaling_resize_h, upscaling_crop) + cache_key = LruCache.Key(image_hash=image_hash, + info_hash=hash(info), + args_hash=hash(upscale_args)) + cached_entry = cached_images.get(cache_key) + if cached_entry is None: + res = upscale(image, *upscale_args) + info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" + cached_images.put(cache_key, LruCache.Value(image=res, info=info)) + else: + res, info = cached_entry.image, cached_entry.info + + if blended_result is None: + blended_result = res + else: + blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) + return (blended_result, info) + + # Build a list of operations to run + facefix_ops: List[Callable] = [] + facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else [] + facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] + + upscale_ops: List[Callable] = [] + upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] + + if upscaling_resize != 0: + step_params: List[UpscaleParams] = [] + step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) + if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: + step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, 
blend_alpha=extras_upscaler_2_visibility)) + + upscale_ops.append(partial(run_upscalers_blend, step_params)) + + extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) + + for image, image_name in zip(imageArr, imageNameArr): + if image is None: + return outputs, "Please select an input image.", '' + + shared.state.textinfo = f'Processing image {image_name}' + + existing_pnginfo = image.info or {} + + image = image.convert("RGB") + info = "" + # Run each operation on each image + for op in extras_ops: + image, info = op(image, info) + + if opts.use_original_name_batch and image_name is not None: + basename = os.path.splitext(os.path.basename(image_name))[0] + else: + basename = '' + + if opts.enable_pnginfo: # append info before save + image.info = existing_pnginfo + image.info["extras"] = info + + if save_output: + # Add upscaler name as a suffix. + suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" + # Add second upscaler if applicable. + if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: + suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" + + images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, + no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) + + if extras_mode != 2 or show_extras_results : + outputs.append(image) + + devices.torch_gc() + + return outputs, plaintext_to_html(info), '' + +def clear_cache(): + cached_images.clear() + + +def run_pnginfo(image): + if image is None: + return '', '', '' + + geninfo, items = images.read_info_from_image(image) + items = {**{'parameters': geninfo}, **items} + + info = '' + for key, text in items.items(): + info += f""" +
<div>
+<p><b>{plaintext_to_html(str(key))}</b></p>
+<p>{plaintext_to_html(str(text))}</p>
+</div>
      +""".strip()+"\n" + + if len(info) == 0: + message = "Nothing found in the image." + info = f"

<div><p>{message}</p></div>

      " + + return '', geninfo, info + + +def create_config(ckpt_result, config_source, a, b, c): + def config(x): + res = sd_models.find_checkpoint_config(x) if x else None + return res if res != shared.sd_default_config else None + + if config_source == 0: + cfg = config(a) or config(b) or config(c) + elif config_source == 1: + cfg = config(b) + elif config_source == 2: + cfg = config(c) + else: + cfg = None + + if cfg is None: + return + + filename, _ = os.path.splitext(ckpt_result) + checkpoint_filename = filename + ".yaml" + + print("Copying config:") + print(" from:", cfg) + print(" to:", checkpoint_filename) + shutil.copyfile(cfg, checkpoint_filename) + + +checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] + + +def to_half(tensor, enable): + if enable and tensor.dtype == torch.float: + return tensor.half() + + return tensor + + +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights): + shared.state.begin() + shared.state.job = 'model-merge' + + def fail(message): + shared.state.textinfo = message + shared.state.end() + return [*[gr.update() for _ in range(4)], message] + + def weighted_sum(theta0, theta1, alpha): + return ((1 - alpha) * theta0) + (alpha * theta1) + + def get_difference(theta1, theta2): + return theta1 - theta2 + + def add_difference(theta0, theta1_2_diff, alpha): + return theta0 + (alpha * theta1_2_diff) + + def filename_weighted_sum(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + Ma = round(1 - multiplier, 2) + Mb = round(multiplier, 2) + + return f"{Ma}({a}) + {Mb}({b})" + + def filename_add_difference(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + c = tertiary_model_info.model_name + M = round(multiplier, 2) + + return f"{a} + {M}({b} - {c})" + + def filename_nothing(): + return primary_model_info.model_name + + theta_funcs = { + "Weighted sum": (filename_weighted_sum, None, weighted_sum), + "Add difference": (filename_add_difference, get_difference, add_difference), + "No interpolation": (filename_nothing, None, None), + } + filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] + shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0) + + if not primary_model_name: + return fail("Failed: Merging requires a primary model.") + + primary_model_info = sd_models.checkpoints_list[primary_model_name] + + if theta_func2 and not secondary_model_name: + return fail("Failed: Merging requires a secondary model.") + + secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None + + if theta_func1 and not tertiary_model_name: + return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") + + tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None + + result_is_inpainting_model = False + + if theta_func2: + shared.state.textinfo = f"Loading B" + print(f"Loading {secondary_model_info.filename}...") + theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') + else: + theta_1 = None + + if theta_func1: + shared.state.textinfo = f"Loading C" + print(f"Loading {tertiary_model_info.filename}...") + theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') + + shared.state.textinfo = 'Merging B and C' + 
shared.state.sampling_steps = len(theta_1.keys()) + for key in tqdm.tqdm(theta_1.keys()): + if key in checkpoint_dict_skip_on_merge: + continue + + if 'model' in key: + if key in theta_2: + t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) + theta_1[key] = theta_func1(theta_1[key], t2) + else: + theta_1[key] = torch.zeros_like(theta_1[key]) + + shared.state.sampling_step += 1 + del theta_2 + + shared.state.nextjob() + + shared.state.textinfo = f"Loading {primary_model_info.filename}..." + print(f"Loading {primary_model_info.filename}...") + theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') + + print("Merging...") + shared.state.textinfo = 'Merging A and B' + shared.state.sampling_steps = len(theta_0.keys()) + for key in tqdm.tqdm(theta_0.keys()): + if theta_1 and 'model' in key and key in theta_1: + + if key in checkpoint_dict_skip_on_merge: + continue + + a = theta_0[key] + b = theta_1[key] + + # this enables merging an inpainting model (A) with another one (B); + # where normal model would have 4 channels, for latenst space, inpainting model would + # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 + if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]: + if a.shape[1] == 4 and b.shape[1] == 9: + raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") + + assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" + + theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) + result_is_inpainting_model = True + else: + theta_0[key] = theta_func2(a, b, multiplier) + + theta_0[key] = to_half(theta_0[key], save_as_half) + + shared.state.sampling_step += 1 + + del theta_1 + + bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None) + if bake_in_vae_filename is not None: + print(f"Baking in VAE from {bake_in_vae_filename}") + shared.state.textinfo = 'Baking in VAE' + vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') + + for key in vae_dict.keys(): + theta_0_key = 'first_stage_model.' + key + if theta_0_key in theta_0: + theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half) + + del vae_dict + + if save_as_half and not theta_func2: + for key in theta_0.keys(): + theta_0[key] = to_half(theta_0[key], save_as_half) + + if discard_weights: + regex = re.compile(discard_weights) + for key in list(theta_0): + if re.search(regex, key): + theta_0.pop(key, None) + + ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path + + filename = filename_generator() if custom_name == '' else custom_name + filename += ".inpainting" if result_is_inpainting_model else "" + filename += "." 
+ checkpoint_format + + output_modelname = os.path.join(ckpt_dir, filename) + + shared.state.nextjob() + shared.state.textinfo = "Saving" + print(f"Saving to {output_modelname}...") + + _, extension = os.path.splitext(output_modelname) + if extension.lower() == ".safetensors": + safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"}) + else: + torch.save(theta_0, output_modelname) + + sd_models.list_models() + + create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) + + print(f"Checkpoint saved to {output_modelname}.") + shared.state.textinfo = "Checkpoint saved" + shared.state.end() + + return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] -- cgit v1.2.3 From c56b36712289020a98f0c77794b9045a251ecd55 Mon Sep 17 00:00:00 2001 From: Andrey <16777216c@gmail.com> Date: Sun, 22 Jan 2023 15:26:41 +0300 Subject: Split history extras.py to postprocessing.py --- modules/extras.py | 466 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ modules/temp | 466 ------------------------------------------------------ 2 files changed, 466 insertions(+), 466 deletions(-) create mode 100644 modules/extras.py delete mode 100644 modules/temp diff --git a/modules/extras.py b/modules/extras.py new file mode 100644 index 00000000..385430dc --- /dev/null +++ b/modules/extras.py @@ -0,0 +1,466 @@ +from __future__ import annotations +import math +import os +import re +import sys +import traceback +import shutil + +import numpy as np +from PIL import Image + +import torch +import tqdm + +from typing import Callable, List, OrderedDict, Tuple +from functools import partial +from dataclasses import dataclass + +from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae +from modules.shared import opts +import modules.gfpgan_model +from modules.ui import plaintext_to_html +import modules.codeformer_model +import gradio as gr +import safetensors.torch + +class LruCache(OrderedDict): + @dataclass(frozen=True) + class Key: + image_hash: int + info_hash: int + args_hash: int + + @dataclass + class Value: + image: Image.Image + info: str + + def __init__(self, max_size: int = 5, *args, **kwargs): + super().__init__(*args, **kwargs) + self._max_size = max_size + + def get(self, key: LruCache.Key) -> LruCache.Value: + ret = super().get(key) + if ret is not None: + self.move_to_end(key) # Move to end of eviction list + return ret + + def put(self, key: LruCache.Key, value: LruCache.Value) -> None: + self[key] = value + while len(self) > self._max_size: + self.popitem(last=False) + + +cached_images: LruCache = LruCache(max_size=5) + + +def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): + devices.torch_gc() + + shared.state.begin() + shared.state.job = 'extras' + + imageArr = [] + # Also keep track of original file names + imageNameArr = [] + outputs = [] + + if extras_mode == 1: + #convert file to pillow image + for img in image_folder: + image = Image.open(img) + imageArr.append(image) + imageNameArr.append(os.path.splitext(img.orig_name)[0]) + elif extras_mode == 2: + assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option 
must be disabled' + + if input_dir == '': + return outputs, "Please select an input directory.", '' + image_list = shared.listfiles(input_dir) + for img in image_list: + try: + image = Image.open(img) + except Exception: + continue + imageArr.append(image) + imageNameArr.append(img) + else: + imageArr.append(image) + imageNameArr.append(None) + + if extras_mode == 2 and output_dir != '': + outpath = output_dir + else: + outpath = opts.outdir_samples or opts.outdir_extras_samples + + # Extra operation definitions + + def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + shared.state.job = 'extras-gfpgan' + restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) + res = Image.fromarray(restored_img) + + if gfpgan_visibility < 1.0: + res = Image.blend(image, res, gfpgan_visibility) + + info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" + return (res, info) + + def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + shared.state.job = 'extras-codeformer' + restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) + res = Image.fromarray(restored_img) + + if codeformer_visibility < 1.0: + res = Image.blend(image, res, codeformer_visibility) + + info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" + return (res, info) + + def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): + shared.state.job = 'extras-upscale' + upscaler = shared.sd_upscalers[scaler_index] + res = upscaler.scaler.upscale(image, resize, upscaler.data_path) + if mode == 1 and crop: + cropped = Image.new("RGB", (resize_w, resize_h)) + cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) + res = cropped + return res + + def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: + # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text + nonlocal upscaling_resize + if resize_mode == 1: + upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) + crop_info = " (crop)" if upscaling_crop else "" + info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" + return (image, info) + + @dataclass + class UpscaleParams: + upscaler_idx: int + blend_alpha: float + + def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: + blended_result: Image.Image = None + image_hash: str = hash(np.array(image.getdata()).tobytes()) + for upscaler in params: + upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, + upscaling_resize_w, upscaling_resize_h, upscaling_crop) + cache_key = LruCache.Key(image_hash=image_hash, + info_hash=hash(info), + args_hash=hash(upscale_args)) + cached_entry = cached_images.get(cache_key) + if cached_entry is None: + res = upscale(image, *upscale_args) + info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" + cached_images.put(cache_key, LruCache.Value(image=res, info=info)) + else: + res, info = cached_entry.image, cached_entry.info + + if blended_result is None: + blended_result = res + else: + blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) + return (blended_result, info) + + # Build a list of operations to run + facefix_ops: List[Callable] = [] + facefix_ops += 
[run_gfpgan] if gfpgan_visibility > 0 else [] + facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] + + upscale_ops: List[Callable] = [] + upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] + + if upscaling_resize != 0: + step_params: List[UpscaleParams] = [] + step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) + if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: + step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility)) + + upscale_ops.append(partial(run_upscalers_blend, step_params)) + + extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) + + for image, image_name in zip(imageArr, imageNameArr): + if image is None: + return outputs, "Please select an input image.", '' + + shared.state.textinfo = f'Processing image {image_name}' + + existing_pnginfo = image.info or {} + + image = image.convert("RGB") + info = "" + # Run each operation on each image + for op in extras_ops: + image, info = op(image, info) + + if opts.use_original_name_batch and image_name is not None: + basename = os.path.splitext(os.path.basename(image_name))[0] + else: + basename = '' + + if opts.enable_pnginfo: # append info before save + image.info = existing_pnginfo + image.info["extras"] = info + + if save_output: + # Add upscaler name as a suffix. + suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" + # Add second upscaler if applicable. + if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: + suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" + + images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, + no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) + + if extras_mode != 2 or show_extras_results : + outputs.append(image) + + devices.torch_gc() + + return outputs, plaintext_to_html(info), '' + +def clear_cache(): + cached_images.clear() + + +def run_pnginfo(image): + if image is None: + return '', '', '' + + geninfo, items = images.read_info_from_image(image) + items = {**{'parameters': geninfo}, **items} + + info = '' + for key, text in items.items(): + info += f""" +
<div>
+<p><b>{plaintext_to_html(str(key))}</b></p>
+<p>{plaintext_to_html(str(text))}</p>
+</div>
      +""".strip()+"\n" + + if len(info) == 0: + message = "Nothing found in the image." + info = f"

<div><p>{message}</p></div>

      " + + return '', geninfo, info + + +def create_config(ckpt_result, config_source, a, b, c): + def config(x): + res = sd_models.find_checkpoint_config(x) if x else None + return res if res != shared.sd_default_config else None + + if config_source == 0: + cfg = config(a) or config(b) or config(c) + elif config_source == 1: + cfg = config(b) + elif config_source == 2: + cfg = config(c) + else: + cfg = None + + if cfg is None: + return + + filename, _ = os.path.splitext(ckpt_result) + checkpoint_filename = filename + ".yaml" + + print("Copying config:") + print(" from:", cfg) + print(" to:", checkpoint_filename) + shutil.copyfile(cfg, checkpoint_filename) + + +checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] + + +def to_half(tensor, enable): + if enable and tensor.dtype == torch.float: + return tensor.half() + + return tensor + + +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights): + shared.state.begin() + shared.state.job = 'model-merge' + + def fail(message): + shared.state.textinfo = message + shared.state.end() + return [*[gr.update() for _ in range(4)], message] + + def weighted_sum(theta0, theta1, alpha): + return ((1 - alpha) * theta0) + (alpha * theta1) + + def get_difference(theta1, theta2): + return theta1 - theta2 + + def add_difference(theta0, theta1_2_diff, alpha): + return theta0 + (alpha * theta1_2_diff) + + def filename_weighted_sum(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + Ma = round(1 - multiplier, 2) + Mb = round(multiplier, 2) + + return f"{Ma}({a}) + {Mb}({b})" + + def filename_add_difference(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + c = tertiary_model_info.model_name + M = round(multiplier, 2) + + return f"{a} + {M}({b} - {c})" + + def filename_nothing(): + return primary_model_info.model_name + + theta_funcs = { + "Weighted sum": (filename_weighted_sum, None, weighted_sum), + "Add difference": (filename_add_difference, get_difference, add_difference), + "No interpolation": (filename_nothing, None, None), + } + filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] + shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0) + + if not primary_model_name: + return fail("Failed: Merging requires a primary model.") + + primary_model_info = sd_models.checkpoints_list[primary_model_name] + + if theta_func2 and not secondary_model_name: + return fail("Failed: Merging requires a secondary model.") + + secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None + + if theta_func1 and not tertiary_model_name: + return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") + + tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None + + result_is_inpainting_model = False + + if theta_func2: + shared.state.textinfo = f"Loading B" + print(f"Loading {secondary_model_info.filename}...") + theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') + else: + theta_1 = None + + if theta_func1: + shared.state.textinfo = f"Loading C" + print(f"Loading {tertiary_model_info.filename}...") + theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') + + shared.state.textinfo = 'Merging B and C' + 
shared.state.sampling_steps = len(theta_1.keys()) + for key in tqdm.tqdm(theta_1.keys()): + if key in checkpoint_dict_skip_on_merge: + continue + + if 'model' in key: + if key in theta_2: + t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) + theta_1[key] = theta_func1(theta_1[key], t2) + else: + theta_1[key] = torch.zeros_like(theta_1[key]) + + shared.state.sampling_step += 1 + del theta_2 + + shared.state.nextjob() + + shared.state.textinfo = f"Loading {primary_model_info.filename}..." + print(f"Loading {primary_model_info.filename}...") + theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') + + print("Merging...") + shared.state.textinfo = 'Merging A and B' + shared.state.sampling_steps = len(theta_0.keys()) + for key in tqdm.tqdm(theta_0.keys()): + if theta_1 and 'model' in key and key in theta_1: + + if key in checkpoint_dict_skip_on_merge: + continue + + a = theta_0[key] + b = theta_1[key] + + # this enables merging an inpainting model (A) with another one (B); + # where normal model would have 4 channels, for latenst space, inpainting model would + # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 + if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]: + if a.shape[1] == 4 and b.shape[1] == 9: + raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") + + assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" + + theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) + result_is_inpainting_model = True + else: + theta_0[key] = theta_func2(a, b, multiplier) + + theta_0[key] = to_half(theta_0[key], save_as_half) + + shared.state.sampling_step += 1 + + del theta_1 + + bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None) + if bake_in_vae_filename is not None: + print(f"Baking in VAE from {bake_in_vae_filename}") + shared.state.textinfo = 'Baking in VAE' + vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') + + for key in vae_dict.keys(): + theta_0_key = 'first_stage_model.' + key + if theta_0_key in theta_0: + theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half) + + del vae_dict + + if save_as_half and not theta_func2: + for key in theta_0.keys(): + theta_0[key] = to_half(theta_0[key], save_as_half) + + if discard_weights: + regex = re.compile(discard_weights) + for key in list(theta_0): + if re.search(regex, key): + theta_0.pop(key, None) + + ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path + + filename = filename_generator() if custom_name == '' else custom_name + filename += ".inpainting" if result_is_inpainting_model else "" + filename += "." 
+ checkpoint_format + + output_modelname = os.path.join(ckpt_dir, filename) + + shared.state.nextjob() + shared.state.textinfo = "Saving" + print(f"Saving to {output_modelname}...") + + _, extension = os.path.splitext(output_modelname) + if extension.lower() == ".safetensors": + safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"}) + else: + torch.save(theta_0, output_modelname) + + sd_models.list_models() + + create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) + + print(f"Checkpoint saved to {output_modelname}.") + shared.state.textinfo = "Checkpoint saved" + shared.state.end() + + return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] diff --git a/modules/temp b/modules/temp deleted file mode 100644 index 385430dc..00000000 --- a/modules/temp +++ /dev/null @@ -1,466 +0,0 @@ -from __future__ import annotations -import math -import os -import re -import sys -import traceback -import shutil - -import numpy as np -from PIL import Image - -import torch -import tqdm - -from typing import Callable, List, OrderedDict, Tuple -from functools import partial -from dataclasses import dataclass - -from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae -from modules.shared import opts -import modules.gfpgan_model -from modules.ui import plaintext_to_html -import modules.codeformer_model -import gradio as gr -import safetensors.torch - -class LruCache(OrderedDict): - @dataclass(frozen=True) - class Key: - image_hash: int - info_hash: int - args_hash: int - - @dataclass - class Value: - image: Image.Image - info: str - - def __init__(self, max_size: int = 5, *args, **kwargs): - super().__init__(*args, **kwargs) - self._max_size = max_size - - def get(self, key: LruCache.Key) -> LruCache.Value: - ret = super().get(key) - if ret is not None: - self.move_to_end(key) # Move to end of eviction list - return ret - - def put(self, key: LruCache.Key, value: LruCache.Value) -> None: - self[key] = value - while len(self) > self._max_size: - self.popitem(last=False) - - -cached_images: LruCache = LruCache(max_size=5) - - -def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): - devices.torch_gc() - - shared.state.begin() - shared.state.job = 'extras' - - imageArr = [] - # Also keep track of original file names - imageNameArr = [] - outputs = [] - - if extras_mode == 1: - #convert file to pillow image - for img in image_folder: - image = Image.open(img) - imageArr.append(image) - imageNameArr.append(os.path.splitext(img.orig_name)[0]) - elif extras_mode == 2: - assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' - - if input_dir == '': - return outputs, "Please select an input directory.", '' - image_list = shared.listfiles(input_dir) - for img in image_list: - try: - image = Image.open(img) - except Exception: - continue - imageArr.append(image) - imageNameArr.append(img) - else: - imageArr.append(image) - imageNameArr.append(None) - - if extras_mode == 2 and output_dir != '': - outpath = output_dir - else: - outpath = opts.outdir_samples or opts.outdir_extras_samples - - # Extra 
operation definitions - - def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-gfpgan' - restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) - res = Image.fromarray(restored_img) - - if gfpgan_visibility < 1.0: - res = Image.blend(image, res, gfpgan_visibility) - - info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" - return (res, info) - - def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-codeformer' - restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) - res = Image.fromarray(restored_img) - - if codeformer_visibility < 1.0: - res = Image.blend(image, res, codeformer_visibility) - - info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" - return (res, info) - - def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): - shared.state.job = 'extras-upscale' - upscaler = shared.sd_upscalers[scaler_index] - res = upscaler.scaler.upscale(image, resize, upscaler.data_path) - if mode == 1 and crop: - cropped = Image.new("RGB", (resize_w, resize_h)) - cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) - res = cropped - return res - - def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text - nonlocal upscaling_resize - if resize_mode == 1: - upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) - crop_info = " (crop)" if upscaling_crop else "" - info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" - return (image, info) - - @dataclass - class UpscaleParams: - upscaler_idx: int - blend_alpha: float - - def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: - blended_result: Image.Image = None - image_hash: str = hash(np.array(image.getdata()).tobytes()) - for upscaler in params: - upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, - upscaling_resize_w, upscaling_resize_h, upscaling_crop) - cache_key = LruCache.Key(image_hash=image_hash, - info_hash=hash(info), - args_hash=hash(upscale_args)) - cached_entry = cached_images.get(cache_key) - if cached_entry is None: - res = upscale(image, *upscale_args) - info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" - cached_images.put(cache_key, LruCache.Value(image=res, info=info)) - else: - res, info = cached_entry.image, cached_entry.info - - if blended_result is None: - blended_result = res - else: - blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) - return (blended_result, info) - - # Build a list of operations to run - facefix_ops: List[Callable] = [] - facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else [] - facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] - - upscale_ops: List[Callable] = [] - upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] - - if upscaling_resize != 0: - step_params: List[UpscaleParams] = [] - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) - if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, 
blend_alpha=extras_upscaler_2_visibility)) - - upscale_ops.append(partial(run_upscalers_blend, step_params)) - - extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) - - for image, image_name in zip(imageArr, imageNameArr): - if image is None: - return outputs, "Please select an input image.", '' - - shared.state.textinfo = f'Processing image {image_name}' - - existing_pnginfo = image.info or {} - - image = image.convert("RGB") - info = "" - # Run each operation on each image - for op in extras_ops: - image, info = op(image, info) - - if opts.use_original_name_batch and image_name is not None: - basename = os.path.splitext(os.path.basename(image_name))[0] - else: - basename = '' - - if opts.enable_pnginfo: # append info before save - image.info = existing_pnginfo - image.info["extras"] = info - - if save_output: - # Add upscaler name as a suffix. - suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" - # Add second upscaler if applicable. - if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: - suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" - - images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) - - if extras_mode != 2 or show_extras_results : - outputs.append(image) - - devices.torch_gc() - - return outputs, plaintext_to_html(info), '' - -def clear_cache(): - cached_images.clear() - - -def run_pnginfo(image): - if image is None: - return '', '', '' - - geninfo, items = images.read_info_from_image(image) - items = {**{'parameters': geninfo}, **items} - - info = '' - for key, text in items.items(): - info += f""" -
<div>
-<p><b>{plaintext_to_html(str(key))}</b></p>
-<p>{plaintext_to_html(str(text))}</p>
-</div>
-""".strip()+"\n"
-
-    if len(info) == 0:
-        message = "Nothing found in the image."
-        info = f"<div><p>{message}<p></div>
      " - - return '', geninfo, info - - -def create_config(ckpt_result, config_source, a, b, c): - def config(x): - res = sd_models.find_checkpoint_config(x) if x else None - return res if res != shared.sd_default_config else None - - if config_source == 0: - cfg = config(a) or config(b) or config(c) - elif config_source == 1: - cfg = config(b) - elif config_source == 2: - cfg = config(c) - else: - cfg = None - - if cfg is None: - return - - filename, _ = os.path.splitext(ckpt_result) - checkpoint_filename = filename + ".yaml" - - print("Copying config:") - print(" from:", cfg) - print(" to:", checkpoint_filename) - shutil.copyfile(cfg, checkpoint_filename) - - -checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] - - -def to_half(tensor, enable): - if enable and tensor.dtype == torch.float: - return tensor.half() - - return tensor - - -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights): - shared.state.begin() - shared.state.job = 'model-merge' - - def fail(message): - shared.state.textinfo = message - shared.state.end() - return [*[gr.update() for _ in range(4)], message] - - def weighted_sum(theta0, theta1, alpha): - return ((1 - alpha) * theta0) + (alpha * theta1) - - def get_difference(theta1, theta2): - return theta1 - theta2 - - def add_difference(theta0, theta1_2_diff, alpha): - return theta0 + (alpha * theta1_2_diff) - - def filename_weighted_sum(): - a = primary_model_info.model_name - b = secondary_model_info.model_name - Ma = round(1 - multiplier, 2) - Mb = round(multiplier, 2) - - return f"{Ma}({a}) + {Mb}({b})" - - def filename_add_difference(): - a = primary_model_info.model_name - b = secondary_model_info.model_name - c = tertiary_model_info.model_name - M = round(multiplier, 2) - - return f"{a} + {M}({b} - {c})" - - def filename_nothing(): - return primary_model_info.model_name - - theta_funcs = { - "Weighted sum": (filename_weighted_sum, None, weighted_sum), - "Add difference": (filename_add_difference, get_difference, add_difference), - "No interpolation": (filename_nothing, None, None), - } - filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] - shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0) - - if not primary_model_name: - return fail("Failed: Merging requires a primary model.") - - primary_model_info = sd_models.checkpoints_list[primary_model_name] - - if theta_func2 and not secondary_model_name: - return fail("Failed: Merging requires a secondary model.") - - secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None - - if theta_func1 and not tertiary_model_name: - return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") - - tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None - - result_is_inpainting_model = False - - if theta_func2: - shared.state.textinfo = f"Loading B" - print(f"Loading {secondary_model_info.filename}...") - theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') - else: - theta_1 = None - - if theta_func1: - shared.state.textinfo = f"Loading C" - print(f"Loading {tertiary_model_info.filename}...") - theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') - - shared.state.textinfo = 'Merging B and C' - 
shared.state.sampling_steps = len(theta_1.keys()) - for key in tqdm.tqdm(theta_1.keys()): - if key in checkpoint_dict_skip_on_merge: - continue - - if 'model' in key: - if key in theta_2: - t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) - theta_1[key] = theta_func1(theta_1[key], t2) - else: - theta_1[key] = torch.zeros_like(theta_1[key]) - - shared.state.sampling_step += 1 - del theta_2 - - shared.state.nextjob() - - shared.state.textinfo = f"Loading {primary_model_info.filename}..." - print(f"Loading {primary_model_info.filename}...") - theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') - - print("Merging...") - shared.state.textinfo = 'Merging A and B' - shared.state.sampling_steps = len(theta_0.keys()) - for key in tqdm.tqdm(theta_0.keys()): - if theta_1 and 'model' in key and key in theta_1: - - if key in checkpoint_dict_skip_on_merge: - continue - - a = theta_0[key] - b = theta_1[key] - - # this enables merging an inpainting model (A) with another one (B); - # where normal model would have 4 channels, for latenst space, inpainting model would - # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 - if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]: - if a.shape[1] == 4 and b.shape[1] == 9: - raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") - - assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" - - theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) - result_is_inpainting_model = True - else: - theta_0[key] = theta_func2(a, b, multiplier) - - theta_0[key] = to_half(theta_0[key], save_as_half) - - shared.state.sampling_step += 1 - - del theta_1 - - bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None) - if bake_in_vae_filename is not None: - print(f"Baking in VAE from {bake_in_vae_filename}") - shared.state.textinfo = 'Baking in VAE' - vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') - - for key in vae_dict.keys(): - theta_0_key = 'first_stage_model.' + key - if theta_0_key in theta_0: - theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half) - - del vae_dict - - if save_as_half and not theta_func2: - for key in theta_0.keys(): - theta_0[key] = to_half(theta_0[key], save_as_half) - - if discard_weights: - regex = re.compile(discard_weights) - for key in list(theta_0): - if re.search(regex, key): - theta_0.pop(key, None) - - ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path - - filename = filename_generator() if custom_name == '' else custom_name - filename += ".inpainting" if result_is_inpainting_model else "" - filename += "." 
+ checkpoint_format - - output_modelname = os.path.join(ckpt_dir, filename) - - shared.state.nextjob() - shared.state.textinfo = "Saving" - print(f"Saving to {output_modelname}...") - - _, extension = os.path.splitext(output_modelname) - if extension.lower() == ".safetensors": - safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"}) - else: - torch.save(theta_0, output_modelname) - - sd_models.list_models() - - create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) - - print(f"Checkpoint saved to {output_modelname}.") - shared.state.textinfo = "Checkpoint saved" - shared.state.end() - - return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] -- cgit v1.2.3 From 68303c96e5ab31576a8238a24bf5b6191cf16ed1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 15:38:39 +0300 Subject: split oversize extras.py to postprocessing.py --- modules/extras.py | 217 +------------------------------------- modules/postprocessing.py | 257 +--------------------------------------------- modules/ui.py | 10 +- modules/ui_components.py | 7 ++ webui.py | 1 - 5 files changed, 18 insertions(+), 474 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index 385430dc..f04ddfc2 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -1,231 +1,16 @@ -from __future__ import annotations -import math import os import re -import sys -import traceback import shutil -import numpy as np -from PIL import Image import torch import tqdm -from typing import Callable, List, OrderedDict, Tuple -from functools import partial -from dataclasses import dataclass - -from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae -from modules.shared import opts -import modules.gfpgan_model +from modules import shared, images, sd_models, sd_vae from modules.ui import plaintext_to_html -import modules.codeformer_model import gradio as gr import safetensors.torch -class LruCache(OrderedDict): - @dataclass(frozen=True) - class Key: - image_hash: int - info_hash: int - args_hash: int - - @dataclass - class Value: - image: Image.Image - info: str - - def __init__(self, max_size: int = 5, *args, **kwargs): - super().__init__(*args, **kwargs) - self._max_size = max_size - - def get(self, key: LruCache.Key) -> LruCache.Value: - ret = super().get(key) - if ret is not None: - self.move_to_end(key) # Move to end of eviction list - return ret - - def put(self, key: LruCache.Key, value: LruCache.Value) -> None: - self[key] = value - while len(self) > self._max_size: - self.popitem(last=False) - - -cached_images: LruCache = LruCache(max_size=5) - - -def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): - devices.torch_gc() - - shared.state.begin() - shared.state.job = 'extras' - - imageArr = [] - # Also keep track of original file names - imageNameArr = [] - outputs = [] - - if extras_mode == 1: - #convert file to pillow image - for img in image_folder: - image = Image.open(img) - imageArr.append(image) - imageNameArr.append(os.path.splitext(img.orig_name)[0]) - elif extras_mode == 2: - assert not shared.cmd_opts.hide_ui_dir_config, 
'--hide-ui-dir-config option must be disabled' - - if input_dir == '': - return outputs, "Please select an input directory.", '' - image_list = shared.listfiles(input_dir) - for img in image_list: - try: - image = Image.open(img) - except Exception: - continue - imageArr.append(image) - imageNameArr.append(img) - else: - imageArr.append(image) - imageNameArr.append(None) - - if extras_mode == 2 and output_dir != '': - outpath = output_dir - else: - outpath = opts.outdir_samples or opts.outdir_extras_samples - - # Extra operation definitions - - def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-gfpgan' - restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) - res = Image.fromarray(restored_img) - - if gfpgan_visibility < 1.0: - res = Image.blend(image, res, gfpgan_visibility) - - info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" - return (res, info) - - def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-codeformer' - restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) - res = Image.fromarray(restored_img) - - if codeformer_visibility < 1.0: - res = Image.blend(image, res, codeformer_visibility) - - info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" - return (res, info) - - def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): - shared.state.job = 'extras-upscale' - upscaler = shared.sd_upscalers[scaler_index] - res = upscaler.scaler.upscale(image, resize, upscaler.data_path) - if mode == 1 and crop: - cropped = Image.new("RGB", (resize_w, resize_h)) - cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) - res = cropped - return res - - def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text - nonlocal upscaling_resize - if resize_mode == 1: - upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) - crop_info = " (crop)" if upscaling_crop else "" - info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" - return (image, info) - - @dataclass - class UpscaleParams: - upscaler_idx: int - blend_alpha: float - - def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: - blended_result: Image.Image = None - image_hash: str = hash(np.array(image.getdata()).tobytes()) - for upscaler in params: - upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, - upscaling_resize_w, upscaling_resize_h, upscaling_crop) - cache_key = LruCache.Key(image_hash=image_hash, - info_hash=hash(info), - args_hash=hash(upscale_args)) - cached_entry = cached_images.get(cache_key) - if cached_entry is None: - res = upscale(image, *upscale_args) - info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" - cached_images.put(cache_key, LruCache.Value(image=res, info=info)) - else: - res, info = cached_entry.image, cached_entry.info - - if blended_result is None: - blended_result = res - else: - blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) - return (blended_result, info) - - # Build a list of operations to run - facefix_ops: List[Callable] = [] 
- facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else [] - facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] - - upscale_ops: List[Callable] = [] - upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] - - if upscaling_resize != 0: - step_params: List[UpscaleParams] = [] - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) - if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility)) - - upscale_ops.append(partial(run_upscalers_blend, step_params)) - - extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) - - for image, image_name in zip(imageArr, imageNameArr): - if image is None: - return outputs, "Please select an input image.", '' - - shared.state.textinfo = f'Processing image {image_name}' - - existing_pnginfo = image.info or {} - - image = image.convert("RGB") - info = "" - # Run each operation on each image - for op in extras_ops: - image, info = op(image, info) - - if opts.use_original_name_batch and image_name is not None: - basename = os.path.splitext(os.path.basename(image_name))[0] - else: - basename = '' - - if opts.enable_pnginfo: # append info before save - image.info = existing_pnginfo - image.info["extras"] = info - - if save_output: - # Add upscaler name as a suffix. - suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" - # Add second upscaler if applicable. - if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: - suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" - - images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) - - if extras_mode != 2 or show_extras_results : - outputs.append(image) - - devices.torch_gc() - - return outputs, plaintext_to_html(info), '' - -def clear_cache(): - cached_images.clear() - def run_pnginfo(image): if image is None: diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 385430dc..cb85720b 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -1,28 +1,18 @@ from __future__ import annotations -import math import os -import re -import sys -import traceback -import shutil import numpy as np from PIL import Image -import torch -import tqdm - from typing import Callable, List, OrderedDict, Tuple from functools import partial from dataclasses import dataclass -from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae +from modules import shared, images, devices, ui_components from modules.shared import opts import modules.gfpgan_model -from modules.ui import plaintext_to_html import modules.codeformer_model -import gradio as gr -import safetensors.torch + class LruCache(OrderedDict): @dataclass(frozen=True) @@ -55,7 +45,7 @@ class LruCache(OrderedDict): cached_images: LruCache = LruCache(max_size=5) -def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): +def 
run_postprocessing(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): devices.torch_gc() shared.state.begin() @@ -221,246 +211,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ devices.torch_gc() - return outputs, plaintext_to_html(info), '' + return outputs, ui_components.plaintext_to_html(info), '' + def clear_cache(): cached_images.clear() - -def run_pnginfo(image): - if image is None: - return '', '', '' - - geninfo, items = images.read_info_from_image(image) - items = {**{'parameters': geninfo}, **items} - - info = '' - for key, text in items.items(): - info += f""" -
<div>
-<p><b>{plaintext_to_html(str(key))}</b></p>
-<p>{plaintext_to_html(str(text))}</p>
-</div>
-""".strip()+"\n"
-
-    if len(info) == 0:
-        message = "Nothing found in the image."
-        info = f"<div><p>{message}<p></div>
      " - - return '', geninfo, info - - -def create_config(ckpt_result, config_source, a, b, c): - def config(x): - res = sd_models.find_checkpoint_config(x) if x else None - return res if res != shared.sd_default_config else None - - if config_source == 0: - cfg = config(a) or config(b) or config(c) - elif config_source == 1: - cfg = config(b) - elif config_source == 2: - cfg = config(c) - else: - cfg = None - - if cfg is None: - return - - filename, _ = os.path.splitext(ckpt_result) - checkpoint_filename = filename + ".yaml" - - print("Copying config:") - print(" from:", cfg) - print(" to:", checkpoint_filename) - shutil.copyfile(cfg, checkpoint_filename) - - -checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] - - -def to_half(tensor, enable): - if enable and tensor.dtype == torch.float: - return tensor.half() - - return tensor - - -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights): - shared.state.begin() - shared.state.job = 'model-merge' - - def fail(message): - shared.state.textinfo = message - shared.state.end() - return [*[gr.update() for _ in range(4)], message] - - def weighted_sum(theta0, theta1, alpha): - return ((1 - alpha) * theta0) + (alpha * theta1) - - def get_difference(theta1, theta2): - return theta1 - theta2 - - def add_difference(theta0, theta1_2_diff, alpha): - return theta0 + (alpha * theta1_2_diff) - - def filename_weighted_sum(): - a = primary_model_info.model_name - b = secondary_model_info.model_name - Ma = round(1 - multiplier, 2) - Mb = round(multiplier, 2) - - return f"{Ma}({a}) + {Mb}({b})" - - def filename_add_difference(): - a = primary_model_info.model_name - b = secondary_model_info.model_name - c = tertiary_model_info.model_name - M = round(multiplier, 2) - - return f"{a} + {M}({b} - {c})" - - def filename_nothing(): - return primary_model_info.model_name - - theta_funcs = { - "Weighted sum": (filename_weighted_sum, None, weighted_sum), - "Add difference": (filename_add_difference, get_difference, add_difference), - "No interpolation": (filename_nothing, None, None), - } - filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] - shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0) - - if not primary_model_name: - return fail("Failed: Merging requires a primary model.") - - primary_model_info = sd_models.checkpoints_list[primary_model_name] - - if theta_func2 and not secondary_model_name: - return fail("Failed: Merging requires a secondary model.") - - secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None - - if theta_func1 and not tertiary_model_name: - return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") - - tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None - - result_is_inpainting_model = False - - if theta_func2: - shared.state.textinfo = f"Loading B" - print(f"Loading {secondary_model_info.filename}...") - theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') - else: - theta_1 = None - - if theta_func1: - shared.state.textinfo = f"Loading C" - print(f"Loading {tertiary_model_info.filename}...") - theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') - - shared.state.textinfo = 'Merging B and C' - 
shared.state.sampling_steps = len(theta_1.keys()) - for key in tqdm.tqdm(theta_1.keys()): - if key in checkpoint_dict_skip_on_merge: - continue - - if 'model' in key: - if key in theta_2: - t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) - theta_1[key] = theta_func1(theta_1[key], t2) - else: - theta_1[key] = torch.zeros_like(theta_1[key]) - - shared.state.sampling_step += 1 - del theta_2 - - shared.state.nextjob() - - shared.state.textinfo = f"Loading {primary_model_info.filename}..." - print(f"Loading {primary_model_info.filename}...") - theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') - - print("Merging...") - shared.state.textinfo = 'Merging A and B' - shared.state.sampling_steps = len(theta_0.keys()) - for key in tqdm.tqdm(theta_0.keys()): - if theta_1 and 'model' in key and key in theta_1: - - if key in checkpoint_dict_skip_on_merge: - continue - - a = theta_0[key] - b = theta_1[key] - - # this enables merging an inpainting model (A) with another one (B); - # where normal model would have 4 channels, for latenst space, inpainting model would - # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 - if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]: - if a.shape[1] == 4 and b.shape[1] == 9: - raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") - - assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" - - theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) - result_is_inpainting_model = True - else: - theta_0[key] = theta_func2(a, b, multiplier) - - theta_0[key] = to_half(theta_0[key], save_as_half) - - shared.state.sampling_step += 1 - - del theta_1 - - bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None) - if bake_in_vae_filename is not None: - print(f"Baking in VAE from {bake_in_vae_filename}") - shared.state.textinfo = 'Baking in VAE' - vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') - - for key in vae_dict.keys(): - theta_0_key = 'first_stage_model.' + key - if theta_0_key in theta_0: - theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half) - - del vae_dict - - if save_as_half and not theta_func2: - for key in theta_0.keys(): - theta_0[key] = to_half(theta_0[key], save_as_half) - - if discard_weights: - regex = re.compile(discard_weights) - for key in list(theta_0): - if re.search(regex, key): - theta_0.pop(key, None) - - ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path - - filename = filename_generator() if custom_name == '' else custom_name - filename += ".inpainting" if result_is_inpainting_model else "" - filename += "." 
+ checkpoint_format - - output_modelname = os.path.join(ckpt_dir, filename) - - shared.state.nextjob() - shared.state.textinfo = "Saving" - print(f"Saving to {output_modelname}...") - - _, extension = os.path.splitext(output_modelname) - if extension.lower() == ".safetensors": - safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"}) - else: - torch.save(theta_0, output_modelname) - - sd_models.list_models() - - create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) - - print(f"Checkpoint saved to {output_modelname}.") - shared.state.textinfo = "Checkpoint saved" - shared.state.end() - - return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] diff --git a/modules/ui.py b/modules/ui.py index eb4b7e6b..4116e167 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -20,7 +20,7 @@ import numpy as np from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path @@ -95,8 +95,8 @@ extra_networks_symbol = '\U0001F3B4' # 🎴 def plaintext_to_html(text): - text = "

      " + "
      \n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

      " - return text + return ui_components.plaintext_to_html(text) + def send_gradio_gallery_to_image(x): if len(x) == 0: @@ -1152,7 +1152,7 @@ def create_ui(): result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples) submit.click( - fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']), + fn=wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']), _js="get_extras_tab_index", inputs=[ dummy_component, @@ -1183,7 +1183,7 @@ def create_ui(): parameters_copypaste.add_paste_fields("extras", extras_image, None) extras_image.change( - fn=modules.extras.clear_cache, + fn=postprocessing.clear_cache, inputs=[], outputs=[] ) diff --git a/modules/ui_components.py b/modules/ui_components.py index 46324425..989cc87b 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -1,3 +1,5 @@ +import html + import gradio as gr @@ -47,3 +49,8 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): def get_block_name(self): return "colorpicker" + + +def plaintext_to_html(text): + text = "

      " + "
      \n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

      " + return text diff --git a/webui.py b/webui.py index d235da74..7cf5885e 100644 --- a/webui.py +++ b/webui.py @@ -22,7 +22,6 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__: from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks import modules.codeformer_model as codeformer -import modules.extras import modules.face_restoration import modules.gfpgan_model as gfpgan import modules.img2img -- cgit v1.2.3 From 985c0b8e9abdd67734d638badefb6ea806b1f28b Mon Sep 17 00:00:00 2001 From: Guillermo Moreno Date: Sat, 21 Jan 2023 17:45:36 -0300 Subject: feat(extra-networks): add thumbs view style --- html/image-update.svg | 3 +++ javascript/extraNetworks.js | 2 ++ modules/ui_extra_networks.py | 21 ++++++++------- style.css | 64 +++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 78 insertions(+), 12 deletions(-) create mode 100644 html/image-update.svg diff --git a/html/image-update.svg b/html/image-update.svg new file mode 100644 index 00000000..525e4fc5 --- /dev/null +++ b/html/image-update.svg @@ -0,0 +1,3 @@ + + + diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index c5a9adb3..1bda7c6e 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -6,11 +6,13 @@ function setupExtraNetworksForTab(tabname){ var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea') var refresh = gradioApp().getElementById(tabname+'_extra_refresh') var close = gradioApp().getElementById(tabname+'_extra_close') + var view = gradioApp().getElementById(tabname+'_extra_view') search.classList.add('search') tabs.appendChild(search) tabs.appendChild(refresh) tabs.appendChild(close) + tabs.appendChild(view) search.addEventListener("input", function(evt){ searchTerm = search.value.toLowerCase() diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index af2b8071..ce4801b5 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -25,7 +25,7 @@ class ExtraNetworksPage: def refresh(self): pass - def create_html(self, tabname): + def create_html(self, tabname, view = 'cards'): items_html = '' for item in self.list_items(): @@ -36,7 +36,7 @@ class ExtraNetworksPage: items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) res = f""" -
<div id='{tabname}_{self.name}_cards' class='extra-network-cards'>
+<div id='{tabname}_{self.name}_cards' class='extra-network-{view}'>
 {items_html}
 </div>
      """ @@ -75,6 +75,7 @@ class ExtraNetworksUi: self.button_save_preview = None self.preview_target_filename = None + self.view_dropdown = None self.tabname = None @@ -110,6 +111,7 @@ def create_ui(container, button, tabname): filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") button_close = gr.Button('Close', elem_id=tabname+"_extra_close") + ui.view_dropdown = gr.Dropdown(['cards', 'thumbs'], elem_id=tabname+"_extra_view", label="View as", value='cards') ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) @@ -117,16 +119,17 @@ def create_ui(container, button, tabname): button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container]) button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container]) - def refresh(): + def refresh(view='cards'): res = [] for pg in ui.stored_extra_pages: pg.refresh() - res.append(pg.create_html(ui.tabname)) + res.append(pg.create_html(ui.tabname, view)) return res - button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) + ui.view_dropdown.change(fn=refresh, inputs=[ui.view_dropdown], outputs=ui.pages) + button_refresh.click(fn=refresh, inputs=[ui.view_dropdown], outputs=ui.pages) return ui @@ -139,7 +142,7 @@ def path_is_parent(parent_path, child_path): def setup_ui(ui, gallery): - def save_preview(index, images, filename): + def save_preview(index, images, filename, view='cards'): if len(images) == 0: print("There is no image in gallery to save as a preview.") return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] @@ -161,11 +164,11 @@ def setup_ui(ui, gallery): image.save(filename) - return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] + return [page.create_html(ui.tabname, view) for page in ui.stored_extra_pages] ui.button_save_preview.click( fn=save_preview, - _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}", - inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename], + _js="function(x, y, z, a){console.log(x, y, z, a); return [selected_gallery_index(), y, z, a]}", + inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename, ui.view_dropdown], outputs=[*ui.pages] ) diff --git a/style.css b/style.css index 507acec1..ca0a172b 100644 --- a/style.css +++ b/style.css @@ -784,21 +784,79 @@ footer { display: inline-block; max-width: 16em; margin: 0.3em; + align-self: center; } -.extra-network-cards .nocards{ +#txt2img_extra_view, #img2img_extra_view { + width: auto; +} + +.extra-network-cards .nocards, .extra-network-thumbs .nocards{ margin: 1.25em 0.5em 0.5em 0.5em; } -.extra-network-cards .nocards h1{ +.extra-network-cards .nocards h1, .extra-network-thumbs .nocards h1{ font-size: 1.5em; margin-bottom: 1em; } -.extra-network-cards .nocards li{ +.extra-network-cards .nocards li, .extra-network-thumbs .nocards li{ margin-left: 0.5em; } +.extra-network-thumbs { + display: flex; + flex-flow: row wrap; + gap: 10px; +} + +.extra-network-thumbs .card { + height: 6em; + width: 6em; + cursor: pointer; + background-image: url('./file=html/card-no-preview.png'); + background-size: cover; + background-position: center center; + position: relative; +} + +.extra-network-thumbs .card:hover .additional a { + display: block; +} + 
+.extra-network-thumbs .actions .additional a { + background-image: url('./file=html/image-update.svg'); + background-repeat: no-repeat; + background-size: cover; + background-position: center center; + position: absolute; + top: 0; + left: 0; + width: 24px; + height: 24px; + display: none; + font-size: 0; + text-align: -9999; + background-color: #fff; +} + +.extra-network-thumbs .actions .name { + position: absolute; + bottom: 0; + font-size: 10px; + padding: 3px; + width: 100%; + overflow: hidden; + white-space: nowrap; + text-overflow: ellipsis; + background: rgba(0,0,0,.5); +} + +.extra-network-thumbs .card:hover .actions .name { + white-space: normal; + word-break: break-all; +} + .extra-network-cards .card{ display: inline-block; margin: 0.5em; -- cgit v1.2.3 From 66eef11ce7f3db108225668c573cb4a763a43fb3 Mon Sep 17 00:00:00 2001 From: Guillermo Moreno Date: Sat, 21 Jan 2023 18:27:57 -0300 Subject: feat(extra-networks): add default view setting --- modules/shared.py | 4 ++++ modules/ui_extra_networks.py | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index cd78e50a..e9548864 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -430,6 +430,10 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "deepbooru_filter_tags": OptionInfo("", "filter out those tags from deepbooru output (separated by comma)"), })) +options_templates.update(options_section(('extra_networks', "Extra Networks"), { + "extra_networks_default_view": OptionInfo("cards", "Default view for Extra Networks", gr.Dropdown, { "choices": ["cards", "thumbs"] }), +})) + options_templates.update(options_section(('ui', "User interface"), { "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index ce4801b5..179ba47a 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -25,7 +25,7 @@ class ExtraNetworksPage: def refresh(self): pass - def create_html(self, tabname, view = 'cards'): + def create_html(self, tabname, view=shared.opts.extra_networks_default_view): items_html = '' for item in self.list_items(): @@ -111,7 +111,7 @@ def create_ui(container, button, tabname): filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") button_close = gr.Button('Close', elem_id=tabname+"_extra_close") - ui.view_dropdown = gr.Dropdown(['cards', 'thumbs'], elem_id=tabname+"_extra_view", label="View as", value='cards') + ui.view_dropdown = gr.Dropdown(['cards', 'thumbs'], elem_id=tabname+"_extra_view", label="View as", value=lambda: shared.opts.extra_networks_default_view) ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) @@ -119,7 +119,7 @@ def create_ui(container, button, tabname): button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container]) button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container]) - def refresh(view='cards'): + def refresh(view): res = [] for pg in ui.stored_extra_pages: @@ -142,7 +142,7 @@ def path_is_parent(parent_path, child_path): def setup_ui(ui, gallery): - def save_preview(index, images, 
filename, view='cards'): + def save_preview(index, images, filename, view): if len(images) == 0: print("There is no image in gallery to save as a preview.") return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] -- cgit v1.2.3 From 8a3f85c4cc1910f59e04b5c8355a30c4c42431e5 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sun, 22 Jan 2023 17:08:08 -0500 Subject: adds hires steps to x/y plot and fixes total_steps calculation --- scripts/xy_grid.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 8ff315a7..5990b78d 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -184,6 +184,7 @@ axis_options = [ AxisOption("Var. seed", int, apply_field("subseed")), AxisOption("Var. strength", float, apply_field("subseed_strength")), AxisOption("Steps", int, apply_field("steps")), + AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")), AxisOption("CFG Scale", float, apply_field("cfg_scale")), AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value), AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), @@ -427,10 +428,21 @@ class Script(scripts.Script): total_steps = p.steps * len(xs) * len(ys) if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr: - total_steps *= 2 + if x_opt.label == "Hires steps": + total_steps += sum(xs) * len(ys) + elif y_opt.label == "Hires steps": + total_steps += sum(ys) * len(xs) + elif p.hr_second_pass_steps: + total_steps += p.hr_second_pass_steps * len(xs) * len(ys) + else: + total_steps *= 2 + + total_steps *= p.n_iter - print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})") - shared.total_tqdm.updateTotal(total_steps * p.n_iter) + image_cell_count = p.n_iter * p.batch_size + cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else "" + print(f"X/Y plot will create {len(xs) * len(ys) * image_cell_count} images on a {len(xs)}x{len(ys)} grid{cell_console_text}. 
(Total steps to process: {total_steps})") + shared.total_tqdm.updateTotal(total_steps) grid_infotext = [None] -- cgit v1.2.3 From f80ff3c1e444926879c284be9384a26ca38d4955 Mon Sep 17 00:00:00 2001 From: Guillermo Moreno Date: Sun, 22 Jan 2023 22:01:24 -0300 Subject: feat(extra-networks): remove view dropdown --- javascript/extraNetworks.js | 2 -- modules/ui_extra_networks.py | 20 +++++++++----------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 1bda7c6e..c5a9adb3 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -6,13 +6,11 @@ function setupExtraNetworksForTab(tabname){ var search = gradioApp().querySelector('#'+tabname+'_extra_search textarea') var refresh = gradioApp().getElementById(tabname+'_extra_refresh') var close = gradioApp().getElementById(tabname+'_extra_close') - var view = gradioApp().getElementById(tabname+'_extra_view') search.classList.add('search') tabs.appendChild(search) tabs.appendChild(refresh) tabs.appendChild(close) - tabs.appendChild(view) search.addEventListener("input", function(evt){ searchTerm = search.value.toLowerCase() diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 179ba47a..2ddac3d8 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -25,7 +25,8 @@ class ExtraNetworksPage: def refresh(self): pass - def create_html(self, tabname, view=shared.opts.extra_networks_default_view): + def create_html(self, tabname): + view = shared.opts.extra_networks_default_view items_html = '' for item in self.list_items(): @@ -75,7 +76,6 @@ class ExtraNetworksUi: self.button_save_preview = None self.preview_target_filename = None - self.view_dropdown = None self.tabname = None @@ -111,7 +111,6 @@ def create_ui(container, button, tabname): filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") button_close = gr.Button('Close', elem_id=tabname+"_extra_close") - ui.view_dropdown = gr.Dropdown(['cards', 'thumbs'], elem_id=tabname+"_extra_view", label="View as", value=lambda: shared.opts.extra_networks_default_view) ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) @@ -119,17 +118,16 @@ def create_ui(container, button, tabname): button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container]) button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container]) - def refresh(view): + def refresh(): res = [] for pg in ui.stored_extra_pages: pg.refresh() - res.append(pg.create_html(ui.tabname, view)) + res.append(pg.create_html(ui.tabname)) return res - ui.view_dropdown.change(fn=refresh, inputs=[ui.view_dropdown], outputs=ui.pages) - button_refresh.click(fn=refresh, inputs=[ui.view_dropdown], outputs=ui.pages) + button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) return ui @@ -142,7 +140,7 @@ def path_is_parent(parent_path, child_path): def setup_ui(ui, gallery): - def save_preview(index, images, filename, view): + def save_preview(index, images, filename): if len(images) == 0: print("There is no image in gallery to save as a preview.") return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] @@ -164,11 +162,11 @@ def setup_ui(ui, gallery): image.save(filename) - 
return [page.create_html(ui.tabname, view) for page in ui.stored_extra_pages] + return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] ui.button_save_preview.click( fn=save_preview, - _js="function(x, y, z, a){console.log(x, y, z, a); return [selected_gallery_index(), y, z, a]}", - inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename, ui.view_dropdown], + _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}", + inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename], outputs=[*ui.pages] ) -- cgit v1.2.3 From b5230197a69d36a79fdc4919c59a03e00e872dd3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 09:24:43 +0300 Subject: rework extras tab to use script system --- javascript/ui.js | 5 - modules/api/api.py | 13 +- modules/postprocessing.py | 236 ++++++++----------------------- modules/scripts.py | 28 ++-- modules/scripts_postprocessing.py | 147 +++++++++++++++++++ modules/shared.py | 5 + modules/ui.py | 265 +---------------------------------- modules/ui_common.py | 202 ++++++++++++++++++++++++++ modules/ui_components.py | 6 - modules/ui_postprocessing.py | 57 ++++++++ scripts/postprocessing_codeformer.py | 36 +++++ scripts/postprocessing_gfpgan.py | 33 +++++ scripts/postprocessing_upscale.py | 106 ++++++++++++++ 13 files changed, 675 insertions(+), 464 deletions(-) create mode 100644 modules/scripts_postprocessing.py create mode 100644 modules/ui_common.py create mode 100644 modules/ui_postprocessing.py create mode 100644 scripts/postprocessing_codeformer.py create mode 100644 scripts/postprocessing_gfpgan.py create mode 100644 scripts/postprocessing_upscale.py diff --git a/javascript/ui.js b/javascript/ui.js index 77256e15..ba72623c 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -104,11 +104,6 @@ function create_tab_index_args(tabId, args){ return res } -function get_extras_tab_index(){ - const [,,...args] = [...arguments] - return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args] -} - function get_img2img_tab_index() { let res = args_to_array(arguments) res.splice(-2) diff --git a/modules/api/api.py b/modules/api/api.py index f2e9e884..5d60fc0a 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -11,10 +11,9 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images -from modules.extras import run_extras from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork @@ -45,10 +44,8 @@ def validate_sampler_name(name): def setUpscalers(req: dict): reqDict = vars(req) - reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1) - reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2) - reqDict.pop('upscaler_1') - reqDict.pop('upscaler_2') + reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None) + reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None) return reqDict def decode_base64_to_image(encoding): @@ -244,7 +241,7 @@ class Api: 
reqDict['image'] = decode_base64_to_image(reqDict['image']) with self.queue_lock: - result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict) + result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict) return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1]) @@ -260,7 +257,7 @@ class Api: reqDict.pop('imageList') with self.queue_lock: - result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict) + result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict) return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1]) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index cb85720b..8514fea7 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -1,219 +1,103 @@ -from __future__ import annotations import os -import numpy as np from PIL import Image -from typing import Callable, List, OrderedDict, Tuple -from functools import partial -from dataclasses import dataclass - -from modules import shared, images, devices, ui_components +from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste from modules.shared import opts -import modules.gfpgan_model -import modules.codeformer_model - - -class LruCache(OrderedDict): - @dataclass(frozen=True) - class Key: - image_hash: int - info_hash: int - args_hash: int - - @dataclass - class Value: - image: Image.Image - info: str - - def __init__(self, max_size: int = 5, *args, **kwargs): - super().__init__(*args, **kwargs) - self._max_size = max_size - - def get(self, key: LruCache.Key) -> LruCache.Value: - ret = super().get(key) - if ret is not None: - self.move_to_end(key) # Move to end of eviction list - return ret - - def put(self, key: LruCache.Key, value: LruCache.Value) -> None: - self[key] = value - while len(self) > self._max_size: - self.popitem(last=False) - - -cached_images: LruCache = LruCache(max_size=5) -def run_postprocessing(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): +def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): devices.torch_gc() shared.state.begin() shared.state.job = 'extras' - imageArr = [] - # Also keep track of original file names - imageNameArr = [] + image_data = [] + image_names = [] outputs = [] if extras_mode == 1: - #convert file to pillow image for img in image_folder: image = Image.open(img) - imageArr.append(image) - imageNameArr.append(os.path.splitext(img.orig_name)[0]) + image_data.append(image) + image_names.append(os.path.splitext(img.orig_name)[0]) elif extras_mode == 2: assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' + assert input_dir, 'input directory not selected' - if input_dir == '': - return outputs, "Please select an input directory.", '' image_list = shared.listfiles(input_dir) - for img in image_list: + for filename in image_list: try: - image = Image.open(img) + image = 
Image.open(filename)
                except Exception:
                    continue
-                imageArr.append(image)
-                imageNameArr.append(img)
+                image_data.append(image)
+                image_names.append(filename)
     else:
-        imageArr.append(image)
-        imageNameArr.append(None)
+        assert image, 'image not selected'
+
+        image_data.append(image)
+        image_names.append(None)
 
     if extras_mode == 2 and output_dir != '':
         outpath = output_dir
     else:
         outpath = opts.outdir_samples or opts.outdir_extras_samples
 
-    # Extra operation definitions
-
-    def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
-        shared.state.job = 'extras-gfpgan'
-        restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
-        res = Image.fromarray(restored_img)
-
-        if gfpgan_visibility < 1.0:
-            res = Image.blend(image, res, gfpgan_visibility)
-
-        info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
-        return (res, info)
-
-    def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
-        shared.state.job = 'extras-codeformer'
-        restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
-        res = Image.fromarray(restored_img)
-
-        if codeformer_visibility < 1.0:
-            res = Image.blend(image, res, codeformer_visibility)
-
-        info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
-        return (res, info)
-
-    def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
-        shared.state.job = 'extras-upscale'
-        upscaler = shared.sd_upscalers[scaler_index]
-        res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
-        if mode == 1 and crop:
-            cropped = Image.new("RGB", (resize_w, resize_h))
-            cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
-            res = cropped
-        return res
-
-    def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
-        # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
-        nonlocal upscaling_resize
-        if resize_mode == 1:
-            upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
-            crop_info = " (crop)" if upscaling_crop else ""
-            info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
-        return (image, info)
-
-    @dataclass
-    class UpscaleParams:
-        upscaler_idx: int
-        blend_alpha: float
-
-    def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
-        blended_result: Image.Image = None
-        image_hash: str = hash(np.array(image.getdata()).tobytes())
-        for upscaler in params:
-            upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
-                            upscaling_resize_w, upscaling_resize_h, upscaling_crop)
-            cache_key = LruCache.Key(image_hash=image_hash,
-                                     info_hash=hash(info),
-                                     args_hash=hash(upscale_args))
-            cached_entry = cached_images.get(cache_key)
-            if cached_entry is None:
-                res = upscale(image, *upscale_args)
-                info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
-                cached_images.put(cache_key, LruCache.Value(image=res, info=info))
-            else:
-                res, info = cached_entry.image, cached_entry.info
-
-            if blended_result is None:
-                blended_result = res
-            else:
-                blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
-        return (blended_result, info)
-
-    # Build a list of operations to run
-    facefix_ops: List[Callable] = []
-    facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
-    facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []
-
-    upscale_ops: List[Callable] = []
-    upscale_ops += [run_prepare_crop] if resize_mode == 1 else []
-
-    if upscaling_resize != 0:
-        step_params: List[UpscaleParams] = []
-        step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0))
-        if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
-            step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility))
-
-        upscale_ops.append(partial(run_upscalers_blend, step_params))
-
-    extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)
-
-    for image, image_name in zip(imageArr, imageNameArr):
-        if image is None:
-            return outputs, "Please select an input image.", ''
-
-        shared.state.textinfo = f'Processing image {image_name}'
-
+    infotext = ''
+
+    for image, name in zip(image_data, image_names):
+        shared.state.textinfo = name
+
         existing_pnginfo = image.info or {}
 
-        image = image.convert("RGB")
-        info = ""
-        # Run each operation on each image
-        for op in extras_ops:
-            image, info = op(image, info)
+        pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
 
-        if opts.use_original_name_batch and image_name is not None:
-            basename = os.path.splitext(os.path.basename(image_name))[0]
+        scripts.scripts_postproc.run(pp, args)
+
+        if opts.use_original_name_batch and name is not None:
+            basename = os.path.splitext(os.path.basename(name))[0]
         else:
            basename = ''
 
-        if opts.enable_pnginfo: # append info before save
-            image.info = existing_pnginfo
-            image.info["extras"] = info
+        infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])
 
-        if save_output:
-            # Add upscaler name as a suffix.
-            suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
-            # Add second upscaler if applicable.
-            if suffix and extras_upscaler_2 and extras_upscaler_2_visibility:
-                suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}"
+        if opts.enable_pnginfo:
+            pp.image.info = existing_pnginfo
+            pp.image.info["postprocessing"] = infotext
 
-            images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
-                              no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
+        if save_output:
+            images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=pp.info, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
 
-        if extras_mode != 2 or show_extras_results :
-            outputs.append(image)
+        if extras_mode != 2 or show_extras_results:
+            outputs.append(pp.image)
 
     devices.torch_gc()
 
-    return outputs, ui_components.plaintext_to_html(info), ''
-
-
-def clear_cache():
-    cached_images.clear()
-
+    return outputs, ui_common.plaintext_to_html(infotext), ''
+
+
+def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
+    """old handler for API"""
+
+    args = scripts.scripts_postproc.create_args_for_run({
+        "Upscale": {
+            "upscale_mode": resize_mode,
+            "upscale_by": upscaling_resize,
+            "upscale_to_width": upscaling_resize_w,
+            "upscale_to_height": upscaling_resize_h,
+            "upscale_crop": upscaling_crop,
+            "upscaler_1_name": extras_upscaler_1,
+            "upscaler_2_name": extras_upscaler_2,
+            "upscaler_2_visibility": extras_upscaler_2_visibility,
+        },
+        "GFPGAN": {
+            "gfpgan_visibility": gfpgan_visibility,
+        },
+        "CodeFormer": {
+            "codeformer_visibility": codeformer_visibility,
+            "codeformer_weight": codeformer_weight,
+        },
+    })
+
+    return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
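For reference, the "postprocessing" infotext that run_postprocessing assembles above is just a comma-joined view of pp.info. A simplified sketch of the same join, with quote() approximated by plain f-string formatting (the real generation_parameters_copypaste.quote adds quoting for values containing commas):

    info = {"Postprocess upscaler": "ESRGAN_4x", "GFPGAN visibility": 0.5}
    infotext = ", ".join(k if k == v else f"{k}: {v}" for k, v in info.items() if v is not None)
    print(infotext)  # Postprocess upscaler: ESRGAN_4x, GFPGAN visibility: 0.5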
diff --git a/modules/scripts.py b/modules/scripts.py
index 4ffc369b..03907a63 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -7,7 +7,7 @@ from collections import namedtuple
 import gradio as gr
 
 from modules.processing import StableDiffusionProcessing
-from modules import shared, paths, script_callbacks, extensions, script_loading
+from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
 
 AlwaysVisible = object()
 
@@ -150,8 +150,10 @@ def basedir():
     return current_basedir
 
 
-scripts_data = []
 ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
+
+scripts_data = []
+postprocessing_scripts_data = []
 ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
 
 
@@ -190,23 +192,31 @@ def list_files_with_name(filename):
 def load_scripts():
     global current_basedir
     scripts_data.clear()
+    postprocessing_scripts_data.clear()
     script_callbacks.clear_callbacks()
 
     scripts_list = list_scripts("scripts", ".py")
 
     syspath = sys.path
 
+    def register_scripts_from_module(module):
+        for key, script_class in module.__dict__.items():
+            if type(script_class) != type:
+                continue
+
+            if issubclass(script_class, Script):
+                scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
+            elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
+                postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
+
     for scriptfile in sorted(scripts_list):
         try:
             if scriptfile.basedir != paths.script_path:
                 sys.path = [scriptfile.basedir] + sys.path
             current_basedir = scriptfile.basedir
 
-            module = script_loading.load_module(scriptfile.path)
-
-            for key, script_class in module.__dict__.items():
-                if type(script_class) == type and issubclass(script_class, Script):
-                    scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
+            script_module = script_loading.load_module(scriptfile.path)
+            register_scripts_from_module(script_module)
 
         except Exception:
             print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
@@ -413,6 +423,7 @@ class ScriptRunner:
 
 scripts_txt2img = ScriptRunner()
 scripts_img2img = ScriptRunner()
+scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
 scripts_current: ScriptRunner = None
 
 
@@ -423,12 +434,13 @@ def reload_script_body_only():
 
 
 def reload_scripts():
-    global scripts_txt2img, scripts_img2img
+    global scripts_txt2img, scripts_img2img, scripts_postproc
 
     load_scripts()
 
     scripts_txt2img = ScriptRunner()
     scripts_img2img = ScriptRunner()
+    scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
 
 
 def IOComponent_init(self, *args, **kwargs):
diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py
new file mode 100644
index 00000000..25de02d0
--- /dev/null
+++ b/modules/scripts_postprocessing.py
@@ -0,0 +1,147 @@
+import os
+import gradio as gr
+
+from modules import errors, shared
+
+
+class PostprocessedImage:
+    def __init__(self, image):
+        self.image = image
+        self.info = {}
+
+
+class ScriptPostprocessing:
+    filename = None
+    controls = None
+    args_from = None
+    args_to = None
+
+    order = 1000
+    """scripts will be ordred by this value in postprocessing UI"""
+
+    name = None
+    """this function should return the title of the script."""
+
+    group = None
+    """A gr.Group component that has all script's UI inside it"""
+
+    def ui(self):
+        """
+        This function should create gradio UI elements. See https://gradio.app/docs/#components
+        The return value should be a dictionary that maps parameter names to components used in processing.
+        Values of those components will be passed to process() function.
+        """
+
+        pass
+
+    def process(self, pp: PostprocessedImage, **args):
+        """
+        This function is called to postprocess the image.
+        args contains a dictionary with all values returned by components from ui()
+        """
+
+        pass
+
+    def image_changed(self):
+        pass
+
+
+def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
+    try:
+        res = func(*args, **kwargs)
+        return res
+    except Exception as e:
+        errors.display(e, f"calling {filename}/{funcname}")
+
+    return default
+
+
+class ScriptPostprocessingRunner:
+    def __init__(self):
+        self.scripts = None
+        self.ui_created = False
+
+    def initialize_scripts(self, scripts_data):
+        self.scripts = []
+
+        for script_class, path, basedir, script_module in scripts_data:
+            script: ScriptPostprocessing = script_class()
+            script.filename = path
+
+            self.scripts.append(script)
+
+    def create_script_ui(self, script, inputs):
+        script.args_from = len(inputs)
+        script.args_to = len(inputs)
+
+        script.controls = wrap_call(script.ui, script.filename, "ui")
+
+        for control in script.controls.values():
+            control.custom_script_source = os.path.basename(script.filename)
+
+        inputs += list(script.controls.values())
+        script.args_to = len(inputs)
+
+    def scripts_in_preferred_order(self):
+        if self.scripts is None:
+            import modules.scripts
+            self.initialize_scripts(modules.scripts.postprocessing_scripts_data)
+
+        scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")]
+
+        def script_score(name):
+            name = name.lower()
+            for i, possible_match in enumerate(scripts_order):
+                if possible_match in name:
+                    return i
+
+            return len(self.scripts)
+
+        script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)}
+
+        return sorted(self.scripts, key=lambda x: script_scores[x.name])
+
+    def setup_ui(self):
+        inputs = []
+
+        for script in self.scripts_in_preferred_order():
+            with gr.Box() as group:
+                self.create_script_ui(script, inputs)
+
+            script.group = group
+
+        self.ui_created = True
+        return inputs
+
+    def run(self, pp: PostprocessedImage, args):
+        for script in self.scripts_in_preferred_order():
+            shared.state.job = script.name
+
+            script_args = args[script.args_from:script.args_to]
+
+            process_args = {}
+            for (name, component), value in zip(script.controls.items(), script_args):
+                process_args[name] = value
+
+            script.process(pp, **process_args)
+
+    def create_args_for_run(self, scripts_args):
+        if not self.ui_created:
+            with gr.Blocks(analytics_enabled=False):
+                self.setup_ui()
+
+        scripts = self.scripts_in_preferred_order()
+        args = [None] * max([x.args_to for x in scripts])
+
+        for script in scripts:
+            script_args_dict = scripts_args.get(script.name, None)
+            if script_args_dict is not None:
+
+                for i, name in enumerate(script.controls):
+                    args[script.args_from + i] = script_args_dict.get(name, None)
+
+        return args
+
+    def image_changed(self):
+        for script in self.scripts_in_preferred_order():
+            script.image_changed()
diff --git a/modules/shared.py b/modules/shared.py
index cd78e50a..cb73bf31 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -474,6 +474,11 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
     'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma"),
 }))
 
+options_templates.update(options_section(('postprocessing', "Postprocessing"), {
+    'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"),
+    'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+}))
+
 options_templates.update(options_section((None, "Hidden options"), {
     "disabled_extensions": OptionInfo([], "Disable those extensions"),
     "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"),
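Taken together, ScriptPostprocessing and the runner above define the extension point that the built-in scripts later in this commit use. A rough sketch of what a third-party script could look like under this API; the class, its "Invert" label, and the invert_enabled parameter are invented for illustration and are not part of the patch:

    from PIL import ImageOps
    import gradio as gr

    from modules import scripts_postprocessing
    from modules.ui_components import FormRow


    class ScriptPostprocessingInvert(scripts_postprocessing.ScriptPostprocessing):
        name = "Invert"   # matched (case-insensitively) against opts.postprocessing_scipts_order
        order = 4000      # sorts after the built-in Upscale (1000), GFPGAN (2000), CodeFormer (3000)

        def ui(self):
            with FormRow():
                invert_enabled = gr.Checkbox(label="Invert colors", value=False, elem_id="extras_invert_enabled")

            # the keys of this dict become the keyword arguments passed to process()
            return {"invert_enabled": invert_enabled}

        def process(self, pp: scripts_postprocessing.PostprocessedImage, invert_enabled):
            if not invert_enabled:
                return

            pp.image = ImageOps.invert(pp.image.convert("RGB"))
            pp.info["Invert"] = "true"  # surfaces in the saved "postprocessing" infotext

Dropped into scripts/, a file like this would be picked up by register_scripts_from_module() above via the new postprocessing_scripts_data list.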
diff --git a/modules/ui.py b/modules/ui.py
index 4116e167..8cb8e613 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -20,7 +20,7 @@ import numpy as np
 from PIL import Image, PngImagePlugin
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
 
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing
 from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
 from modules.paths import script_path
 
@@ -86,7 +86,6 @@ css_hide_progressbar = """
 random_symbol = '\U0001f3b2\ufe0f'  # 🎲️
 reuse_symbol = '\u267b\ufe0f'  # ♻️
 paste_symbol = '\u2199\ufe0f'  # ↙
-folder_symbol = '\U0001f4c2'  # 📂
 refresh_symbol = '\U0001f504'  # 🔄
 save_style_symbol = '\U0001f4be'  # 💾
 apply_style_symbol = '\U0001f4cb'  # 📋
@@ -95,7 +94,7 @@ extra_networks_symbol = '\U0001F3B4'  # 🎴
 
 
 def plaintext_to_html(text):
-    return ui_components.plaintext_to_html(text)
+    return ui_common.plaintext_to_html(text)
 
 
 def send_gradio_gallery_to_image(x):
@@ -103,70 +102,6 @@ def send_gradio_gallery_to_image(x):
         return None
     return image_from_url_text(x[0])
 
-def save_files(js_data, images, do_make_zip, index):
-    import csv
-    filenames = []
-    fullfns = []
-
-    #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
-    class MyObject:
-        def __init__(self, d=None):
-            if d is not None:
-                for key, value in d.items():
-                    setattr(self, key, value)
-
-    data = json.loads(js_data)
-
-    p = MyObject(data)
-    path = opts.outdir_save
-    save_to_dirs = opts.use_save_to_dirs_for_ui
-    extension: str = opts.samples_format
-    start_index = 0
-
-    if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]):  # ensures we are looking at a specific non-grid picture, and we have save_selected_only
-
-        images = [images[index]]
-        start_index = index
-
-    os.makedirs(opts.outdir_save, exist_ok=True)
-
-    with open(os.path.join(opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
-        at_start = file.tell() == 0
-        writer = csv.writer(file)
-        if at_start:
-            writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
-
-        for image_index, filedata in enumerate(images, start_index):
-            image = image_from_url_text(filedata)
-
-            is_grid = image_index < p.index_of_first_image
-            i = 0 if is_grid else (image_index - p.index_of_first_image)
-
-            fullfn, txt_fullfn = save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
-
-            filename = os.path.relpath(fullfn, path)
-            filenames.append(filename)
-            fullfns.append(fullfn)
-            if txt_fullfn:
-                filenames.append(os.path.basename(txt_fullfn))
-                fullfns.append(txt_fullfn)
-
-        writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
-
-    # Make Zip
-    if do_make_zip:
-        zip_filepath = os.path.join(path, "images.zip")
-
-        from zipfile import ZipFile
-        with ZipFile(zip_filepath, "w") as zip_file:
-            for i in range(len(fullfns)):
-                with open(fullfns[i], mode="rb") as f:
-                    zip_file.writestr(filenames[i], f.read())
-        fullfns.insert(0, zip_filepath)
-
-    return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
-
-
 def visit(x, func, path=""):
     if hasattr(x, 'children'):
         for c in x.children:
@@ -444,19 +379,6 @@ def apply_setting(key, value):
         opts.save(shared.config_filename)
     return getattr(opts, key)
 
-
-def update_generation_info(generation_info, html_info, img_index):
-    try:
-        generation_info = json.loads(generation_info)
-        if img_index < 0 or img_index >= len(generation_info["infotexts"]):
-            return html_info, gr.update()
-        return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
-    except Exception:
-        pass
-    # if the json parse or anything else fails, just return the old html_info
-    return html_info, gr.update()
-
-
 def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
     def refresh():
         refresh_method()
@@ -477,107 +399,7 @@ def create_refresh_button(refresh_component, refresh_method, refreshed_args, ele
 
 
 def create_output_panel(tabname, outdir):
-    def open_folder(f):
-        if not os.path.exists(f):
-            print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
-            return
-        elif not os.path.isdir(f):
-            print(f"""
-WARNING
-An open_folder request was made with an argument that is not a folder.
-This could be an error or a malicious attempt to run code on your computer.
-Requested path was: {f}
-""", file=sys.stderr)
-            return
-
-        if not shared.cmd_opts.hide_ui_dir_config:
-            path = os.path.normpath(f)
-            if platform.system() == "Windows":
-                os.startfile(path)
-            elif platform.system() == "Darwin":
-                sp.Popen(["open", path])
-            elif "microsoft-standard-WSL2" in platform.uname().release:
-                sp.Popen(["wsl-open", path])
-            else:
-                sp.Popen(["xdg-open", path])
-
-    with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
-        with gr.Group(elem_id=f"{tabname}_gallery_container"):
-            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
-
-        generation_info = None
-        with gr.Column():
-            with gr.Row(elem_id=f"image_buttons_{tabname}"):
-                open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
-
-                if tabname != "extras":
-                    save = gr.Button('Save', elem_id=f'save_{tabname}')
-                    save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
-
-                buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
-
-            open_folder_button.click(
-                fn=lambda: open_folder(opts.outdir_samples or outdir),
-                inputs=[],
-                outputs=[],
-            )
-
-            if tabname != "extras":
-                with gr.Row():
-                    download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
-
-                with gr.Group():
-                    html_info = gr.HTML(elem_id=f'html_info_{tabname}')
-                    html_log = gr.HTML(elem_id=f'html_log_{tabname}')
-
-                    generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
-                    if tabname == 'txt2img' or tabname == 'img2img':
-                        generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
-                        generation_info_button.click(
-                            fn=update_generation_info,
-                            _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
-                            inputs=[generation_info, html_info, html_info],
-                            outputs=[html_info, html_info],
-                        )
-
-                    save.click(
-                        fn=wrap_gradio_call(save_files),
-                        _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
-                        inputs=[
-                            generation_info,
-                            result_gallery,
-                            html_info,
-                            html_info,
-                        ],
-                        outputs=[
-                            download_files,
-                            html_log,
-                        ],
-                        show_progress=False,
-                    )
-
-                    save_zip.click(
-                        fn=wrap_gradio_call(save_files),
-                        _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
-                        inputs=[
-                            generation_info,
-                            result_gallery,
-                            html_info,
-                            html_info,
-                        ],
-                        outputs=[
-                            download_files,
-                            html_log,
-                        ]
-                    )
-
-            else:
-                html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
-                html_info = gr.HTML(elem_id=f'html_info_{tabname}')
-                html_log = gr.HTML(elem_id=f'html_log_{tabname}')
-
-    parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
-    return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
+    return ui_common.create_output_panel(tabname, outdir)
 
 
 def create_sampler_and_steps_selection(choices, tabname):
@@ -1106,86 +928,7 @@ def create_ui():
         modules.scripts.scripts_current = None
 
     with gr.Blocks(analytics_enabled=False) as extras_interface:
-        with gr.Row().style(equal_height=False):
-            with gr.Column(variant='compact'):
-                with gr.Tabs(elem_id="mode_extras"):
-                    with gr.TabItem('Single Image', elem_id="extras_single_tab"):
-                        extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
-
-                    with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab"):
-                        image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
-
-                    with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab"):
-                        extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
-                        extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
-                        show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
-
-                submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
-
-                with gr.Tabs(elem_id="extras_resize_mode"):
-                    with gr.TabItem('Scale by', elem_id="extras_scale_by_tab"):
-                        upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
-                    with gr.TabItem('Scale to', elem_id="extras_scale_to_tab"):
-                        with gr.Group():
-                            with gr.Row():
-                                upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
-                                upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
-                            upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
-
-                with gr.Group():
-                    extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
-
-                with gr.Group():
-                    extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
-                    extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1, elem_id="extras_upscaler_2_visibility")
-
-                with gr.Group():
-                    gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, interactive=modules.gfpgan_model.have_gfpgan, elem_id="extras_gfpgan_visibility")
-
-                with gr.Group():
-                    codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_visibility")
-                    codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer, elem_id="extras_codeformer_weight")
-
-                with gr.Group():
-                    upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False, elem_id="extras_upscale_before_face_fix")
-
-            result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples)
-
-        submit.click(
-            fn=wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
-            _js="get_extras_tab_index",
-            inputs=[
-                dummy_component,
-                dummy_component,
-                extras_image,
-                image_batch,
-                extras_batch_input_dir,
-                extras_batch_output_dir,
-                show_extras_results,
-                gfpgan_visibility,
-                codeformer_visibility,
-                codeformer_weight,
-                upscaling_resize,
-                upscaling_resize_w,
-                upscaling_resize_h,
-                upscaling_crop,
-                extras_upscaler_1,
-                extras_upscaler_2,
-                extras_upscaler_2_visibility,
-                upscale_before_face_fix,
-            ],
-            outputs=[
-                result_images,
-                html_info_x,
-                html_info,
-            ]
-        )
-        parameters_copypaste.add_paste_fields("extras", extras_image, None)
-
-        extras_image.change(
-            fn=postprocessing.clear_cache,
-            inputs=[], outputs=[]
-        )
+        ui_postprocessing.create_ui()
 
     with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
         with gr.Row().style(equal_height=False):
diff --git a/modules/ui_common.py b/modules/ui_common.py
new file mode 100644
index 00000000..8ce75b8c
--- /dev/null
+++ b/modules/ui_common.py
@@ -0,0 +1,202 @@
+import json
+import html
+import os
+import platform
+import sys
+
+import gradio as gr
+import scipy as sp
+
+from modules import call_queue, shared
+from modules.generation_parameters_copypaste import image_from_url_text
+import modules.images
+
+folder_symbol = '\U0001f4c2'  # 📂
+
+
+def update_generation_info(generation_info, html_info, img_index):
+    try:
+        generation_info = json.loads(generation_info)
+        if img_index < 0 or img_index >= len(generation_info["infotexts"]):
+            return html_info, gr.update()
+        return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update()
+    except Exception:
+        pass
+    # if the json parse or anything else fails, just return the old html_info
+    return html_info, gr.update()
+
+
+def plaintext_to_html(text):
+    text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
+    return text
+
+
+def save_files(js_data, images, do_make_zip, index):
+    import csv
+    filenames = []
+    fullfns = []
+
+    #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it
+    class MyObject:
+        def __init__(self, d=None):
+            if d is not None:
+                for key, value in d.items():
+                    setattr(self, key, value)
+
+    data = json.loads(js_data)
+
+    p = MyObject(data)
+    path = shared.opts.outdir_save
+    save_to_dirs = shared.opts.use_save_to_dirs_for_ui
+    extension: str = shared.opts.samples_format
+    start_index = 0
+
+    if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]):  # ensures we are looking at a specific non-grid picture, and we have save_selected_only
+
+        images = [images[index]]
+        start_index = index
+
+    os.makedirs(shared.opts.outdir_save, exist_ok=True)
+
+    with open(os.path.join(shared.opts.outdir_save, "log.csv"), "a", encoding="utf8", newline='') as file:
+        at_start = file.tell() == 0
+        writer = csv.writer(file)
+        if at_start:
+            writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
+
+        for image_index, filedata in enumerate(images, start_index):
+            image = image_from_url_text(filedata)
+
+            is_grid = image_index < p.index_of_first_image
+            i = 0 if is_grid else (image_index - p.index_of_first_image)
+
+            fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs)
+
+            filename = os.path.relpath(fullfn, path)
+            filenames.append(filename)
+            fullfns.append(fullfn)
+            if txt_fullfn:
+                filenames.append(os.path.basename(txt_fullfn))
+                fullfns.append(txt_fullfn)
+
+        writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])
+
+    # Make Zip
+    if do_make_zip:
+        zip_filepath = os.path.join(path, "images.zip")
+
+        from zipfile import ZipFile
+        with ZipFile(zip_filepath, "w") as zip_file:
+            for i in range(len(fullfns)):
+                with open(fullfns[i], mode="rb") as f:
+                    zip_file.writestr(filenames[i], f.read())
+        fullfns.insert(0, zip_filepath)
+
+    return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}")
+
+
+def create_output_panel(tabname, outdir):
+    from modules import shared
+    import modules.generation_parameters_copypaste as parameters_copypaste
+
+    def open_folder(f):
+        if not os.path.exists(f):
+            print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.')
+            return
+        elif not os.path.isdir(f):
+            print(f"""
+WARNING
+An open_folder request was made with an argument that is not a folder.
+This could be an error or a malicious attempt to run code on your computer.
+Requested path was: {f}
+""", file=sys.stderr)
+            return
+
+        if not shared.cmd_opts.hide_ui_dir_config:
+            path = os.path.normpath(f)
+            if platform.system() == "Windows":
+                os.startfile(path)
+            elif platform.system() == "Darwin":
+                sp.Popen(["open", path])
+            elif "microsoft-standard-WSL2" in platform.uname().release:
+                sp.Popen(["wsl-open", path])
+            else:
+                sp.Popen(["xdg-open", path])
+
+    with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
+        with gr.Group(elem_id=f"{tabname}_gallery_container"):
+            result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4)
+
+        generation_info = None
+        with gr.Column():
+            with gr.Row(elem_id=f"image_buttons_{tabname}"):
+                open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}')
+
+                if tabname != "extras":
+                    save = gr.Button('Save', elem_id=f'save_{tabname}')
+                    save_zip = gr.Button('Zip', elem_id=f'save_zip_{tabname}')
+
+                buttons = parameters_copypaste.create_buttons(["img2img", "inpaint", "extras"])
+
+            open_folder_button.click(
+                fn=lambda: open_folder(shared.opts.outdir_samples or outdir),
+                inputs=[],
+                outputs=[],
+            )
+
+            if tabname != "extras":
+                with gr.Row():
+                    download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}')
+
+                with gr.Group():
+                    html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+                    html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+
+                    generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}')
+                    if tabname == 'txt2img' or tabname == 'img2img':
+                        generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button")
+                        generation_info_button.click(
+                            fn=update_generation_info,
+                            _js="function(x, y, z){ return [x, y, selected_gallery_index()] }",
+                            inputs=[generation_info, html_info, html_info],
+                            outputs=[html_info, html_info],
+                        )
+
+                    save.click(
+                        fn=call_queue.wrap_gradio_call(save_files),
+                        _js="(x, y, z, w) => [x, y, false, selected_gallery_index()]",
+                        inputs=[
+                            generation_info,
+                            result_gallery,
+                            html_info,
+                            html_info,
+                        ],
+                        outputs=[
+                            download_files,
+                            html_log,
+                        ],
+                        show_progress=False,
+                    )
+
+                    save_zip.click(
+                        fn=call_queue.wrap_gradio_call(save_files),
+                        _js="(x, y, z, w) => [x, y, true, selected_gallery_index()]",
+                        inputs=[
+                            generation_info,
+                            result_gallery,
+                            html_info,
+                            html_info,
+                        ],
+                        outputs=[
+                            download_files,
+                            html_log,
+                        ]
+                    )
+
+            else:
+                html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}')
+                html_info = gr.HTML(elem_id=f'html_info_{tabname}')
+                html_log = gr.HTML(elem_id=f'html_log_{tabname}')
+
+    parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None)
+    return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
diff --git a/modules/ui_components.py b/modules/ui_components.py
index 989cc87b..9aec3097 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -1,5 +1,3 @@
-import html
-
 import gradio as gr
 
 
@@ -50,7 +48,3 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
 
     def get_block_name(self):
         return "colorpicker"
-
-
-def plaintext_to_html(text):
-    text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
-    return text
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
new file mode 100644
index 00000000..b418d955
--- /dev/null
+++ b/modules/ui_postprocessing.py
@@ -0,0 +1,57 @@
+import gradio as gr
+from modules import scripts_postprocessing, scripts, shared, gfpgan_model, codeformer_model, ui_common, postprocessing, call_queue
+import modules.generation_parameters_copypaste as parameters_copypaste
+
+
+def create_ui():
+    tab_index = gr.State(value=0)
+
+    with gr.Row().style(equal_height=False, variant='compact'):
+        with gr.Column(variant='compact'):
+            with gr.Tabs(elem_id="mode_extras"):
+                with gr.TabItem('Single Image', elem_id="extras_single_tab") as tab_single:
+                    extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image")
+
+                with gr.TabItem('Batch Process', elem_id="extras_batch_process_tab") as tab_batch:
+                    image_batch = gr.File(label="Batch Process", file_count="multiple", interactive=True, type="file", elem_id="extras_image_batch")
+
+                with gr.TabItem('Batch from Directory', elem_id="extras_batch_directory_tab") as tab_batch_dir:
+                    extras_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, placeholder="A directory on the same machine where the server is running.", elem_id="extras_batch_input_dir")
+                    extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.", elem_id="extras_batch_output_dir")
+                    show_extras_results = gr.Checkbox(label='Show result images', value=True, elem_id="extras_show_extras_results")
+
+            submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
+
+            script_inputs = scripts.scripts_postproc.setup_ui()
+
+        with gr.Column():
+            result_images, html_info_x, html_info, html_log = ui_common.create_output_panel("extras", shared.opts.outdir_extras_samples)
+
+    tab_single.select(fn=lambda: 0, inputs=[], outputs=[tab_index])
+    tab_batch.select(fn=lambda: 1, inputs=[], outputs=[tab_index])
+    tab_batch_dir.select(fn=lambda: 2, inputs=[], outputs=[tab_index])
+
+    submit.click(
+        fn=call_queue.wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
+        inputs=[
+            tab_index,
+            extras_image,
+            image_batch,
+            extras_batch_input_dir,
+            extras_batch_output_dir,
+            show_extras_results,
+            *script_inputs
+        ],
+        outputs=[
+            result_images,
+            html_info_x,
+            html_info,
+        ]
+    )
+
+    parameters_copypaste.add_paste_fields("extras", extras_image, None)
+
+    extras_image.change(
+        fn=scripts.scripts_postproc.image_changed,
+        inputs=[], outputs=[]
+    )
diff --git a/scripts/postprocessing_codeformer.py b/scripts/postprocessing_codeformer.py
new file mode 100644
index 00000000..a7d80d40
--- /dev/null
+++ b/scripts/postprocessing_codeformer.py
@@ -0,0 +1,36 @@
+from PIL import Image
+import numpy as np
+
+from modules import scripts_postprocessing, codeformer_model
+import gradio as gr
+
+from modules.ui_components import FormRow
+
+
+class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing):
+    name = "CodeFormer"
+    order = 3000
+
+    def ui(self):
+        with FormRow():
+            codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, elem_id="extras_codeformer_visibility")
+            codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight")
+
+        return {
+            "codeformer_visibility": codeformer_visibility,
+            "codeformer_weight": codeformer_weight,
+        }
+
+    def process(self, pp: scripts_postprocessing.PostprocessedImage, codeformer_visibility, codeformer_weight):
+        if codeformer_visibility == 0:
+            return
+
+        restored_img = codeformer_model.codeformer.restore(np.array(pp.image, dtype=np.uint8), w=codeformer_weight)
+        res = Image.fromarray(restored_img)
+
+        if codeformer_visibility < 1.0:
+            res = Image.blend(pp.image, res, codeformer_visibility)
+
+        pp.image = res
+        pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3)
+        pp.info["CodeFormer weight"] = round(codeformer_weight, 3)
diff --git a/scripts/postprocessing_gfpgan.py b/scripts/postprocessing_gfpgan.py
new file mode 100644
index 00000000..d854f3f7
--- /dev/null
+++ b/scripts/postprocessing_gfpgan.py
@@ -0,0 +1,33 @@
+from PIL import Image
+import numpy as np
+
+from modules import scripts_postprocessing, gfpgan_model
+import gradio as gr
+
+from modules.ui_components import FormRow
+
+
+class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing):
+    name = "GFPGAN"
+    order = 2000
+
+    def ui(self):
+        with FormRow():
+            gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN visibility", value=0, elem_id="extras_gfpgan_visibility")
+
+        return {
+            "gfpgan_visibility": gfpgan_visibility,
+        }
+
+    def process(self, pp: scripts_postprocessing.PostprocessedImage, gfpgan_visibility):
+        if gfpgan_visibility == 0:
+            return
+
+        restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image, dtype=np.uint8))
+        res = Image.fromarray(restored_img)
+
+        if gfpgan_visibility < 1.0:
+            res = Image.blend(pp.image, res, gfpgan_visibility)
+
+        pp.image = res
+        pp.info["GFPGAN visibility"] = round(gfpgan_visibility, 3)
diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py
new file mode 100644
index 00000000..095d29b2
--- /dev/null
+++ b/scripts/postprocessing_upscale.py
@@ -0,0 +1,106 @@
+from PIL import Image
+import numpy as np
+
+from modules import scripts_postprocessing, shared
+import gradio as gr
+
+from modules.ui_components import FormRow
+
+
+upscale_cache = {}
+
+
+class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing):
+    name = "Upscale"
+    order = 1000
+
+    def ui(self):
+        selected_tab = gr.State(value=0)
+
+        with gr.Tabs(elem_id="extras_resize_mode"):
+            with gr.TabItem('Scale by', elem_id="extras_scale_by_tab") as tab_scale_by:
+                upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4, elem_id="extras_upscaling_resize")
+
+            with gr.TabItem('Scale to', elem_id="extras_scale_to_tab") as tab_scale_to:
+                with FormRow():
+                    upscaling_resize_w = gr.Number(label="Width", value=512, precision=0, elem_id="extras_upscaling_resize_w")
+                    upscaling_resize_h = gr.Number(label="Height", value=512, precision=0, elem_id="extras_upscaling_resize_h")
+                    upscaling_crop = gr.Checkbox(label='Crop to fit', value=True, elem_id="extras_upscaling_crop")
+
+        with FormRow():
+            extras_upscaler_1 = gr.Dropdown(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+
+        with FormRow():
+            extras_upscaler_2 = gr.Dropdown(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name)
+            extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=0.0, elem_id="extras_upscaler_2_visibility")
+
+        tab_scale_by.select(fn=lambda: 0, inputs=[], outputs=[selected_tab])
+        tab_scale_to.select(fn=lambda: 1, inputs=[], outputs=[selected_tab])
+
+        return {
+            "upscale_mode": selected_tab,
+            "upscale_by": upscaling_resize,
+            "upscale_to_width": upscaling_resize_w,
+            "upscale_to_height": upscaling_resize_h,
+            "upscale_crop": upscaling_crop,
+            "upscaler_1_name": extras_upscaler_1,
+            "upscaler_2_name": extras_upscaler_2,
+            "upscaler_2_visibility": extras_upscaler_2_visibility,
+        }
+
+    def upscale(self, image, info, upscaler, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop):
+        if upscale_mode == 1:
+            upscale_by = max(upscale_to_width/image.width, upscale_to_height/image.height)
+            info["Postprocess upscale to"] = f"{upscale_to_width}x{upscale_to_height}"
+        else:
+            info["Postprocess upscale by"] = upscale_by
+
+        cache_key = (hash(np.array(image.getdata()).tobytes()), upscaler.name, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
+        cached_image = upscale_cache.pop(cache_key, None)
+
+        if cached_image is not None:
+            image = cached_image
+        else:
+            image = upscaler.scaler.upscale(image, upscale_by, upscaler.data_path)
+
+        upscale_cache[cache_key] = image
+        if len(upscale_cache) > shared.opts.upscaling_max_images_in_cache:
+            upscale_cache.pop(next(iter(upscale_cache), None), None)
+
+        if upscale_mode == 1 and upscale_crop:
+            cropped = Image.new("RGB", (upscale_to_width, upscale_to_height))
+            cropped.paste(image, box=(upscale_to_width // 2 - image.width // 2, upscale_to_height // 2 - image.height // 2))
+            image = cropped
+            info["Postprocess crop to"] = f"{image.width}x{image.height}"
+
+        return image
+
+    def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_mode=1, upscale_by=2.0, upscale_to_width=None, upscale_to_height=None, upscale_crop=False, upscaler_1_name=None, upscaler_2_name=None, upscaler_2_visibility=0.0):
+        if upscaler_1_name == "None":
+            upscaler_1_name = None
+
+        upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_1_name]), None)
+        assert upscaler1 or (upscaler_1_name is None), f'could not find upscaler named {upscaler_1_name}'
+
+        if not upscaler1:
+            return
+
+        if upscaler_2_name == "None":
+            upscaler_2_name = None
+
+        upscaler2 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_2_name and x.name != "None"]), None)
+        assert upscaler2 or (upscaler_2_name is None), f'could not find upscaler named {upscaler_2_name}'
+
+        upscaled_image = self.upscale(pp.image, pp.info, upscaler1, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
+        pp.info[f"Postprocess upscaler"] = upscaler1.name
+
+        if upscaler2 and upscaler_2_visibility > 0:
+            second_upscale = self.upscale(pp.image, pp.info, upscaler2, upscale_mode, upscale_by, upscale_to_width, upscale_to_height, upscale_crop)
+            upscaled_image = Image.blend(upscaled_image, second_upscale, upscaler_2_visibility)
+
+            pp.info[f"Postprocess upscaler 2"] = upscaler2.name
+
+        pp.image = upscaled_image
+
+    def image_changed(self):
+        upscale_cache.clear()
-- cgit v1.2.3
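A design note on postprocessing_upscale.py above: upscale_cache is a plain dict used as a FIFO cache keyed by image hash and upscale settings; a hit is popped and reinserted at the back, and once the dict grows past the limit the oldest key is evicted. A standalone sketch of the same pattern, with a fixed limit standing in for shared.opts.upscaling_max_images_in_cache:

    upscale_cache = {}
    MAX_ENTRIES = 2  # stands in for shared.opts.upscaling_max_images_in_cache

    def remember(key, value):
        upscale_cache[key] = value
        if len(upscale_cache) > MAX_ENTRIES:
            # dicts keep insertion order (Python 3.7+), so this evicts the oldest entry
            upscale_cache.pop(next(iter(upscale_cache), None), None)

    for k, v in [("a", 1), ("b", 2), ("c", 3)]:
        remember(k, v)

    print(list(upscale_cache))  # ['b', 'c']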
From 669dbd9725b3a285503e093a75c0dfa332073d8a Mon Sep 17 00:00:00 2001
From: Shondoit
Date: Mon, 23 Jan 2023 09:54:42 +0100
Subject: Fix dark mode

Fixes #7048

Co-Authored-By: J.J. Tolton
---
 modules/ui.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/modules/ui.py b/modules/ui.py
index eb4b7e6b..43bcb7e5 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1942,11 +1942,11 @@ def reload_javascript():
     if cmd_opts.theme is not None:
         inline += f"set_theme('{cmd_opts.theme}');"
 
-    head += f'<script type="text/javascript">{inline}</script>\n'
-
     for script in modules.scripts.list_scripts("javascript", ".js"):
         head += f'<script type="text/javascript" src="file={script.path}"></script>\n'
 
+    head += f'<script type="text/javascript">{inline}</script>\n'
+
     def template_response(*args, **kwargs):
         res = shared.GradioTemplateResponseOriginal(*args, **kwargs)
         res.body = res.body.replace(b'</head>', f'{head}</head>'.encode("utf8"))
-- cgit v1.2.3

From fabdae089e476c66eba3b0562e4e1881891804b2 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 23 Jan 2023 14:42:49 +0300
Subject: add missing import to previous commit

---
 modules/ui.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/modules/ui.py b/modules/ui.py
index 8cb8e613..6b5dfd61 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -41,6 +41,7 @@ from modules.sd_samplers import samplers, samplers_for_img2img
 from modules.textual_inversion import textual_inversion
 import modules.hypernetworks.ui
 from modules.generation_parameters_copypaste import image_from_url_text
+import modules.extras
 
 warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
 
-- cgit v1.2.3

From 41265a026de699cc223ca5b76c69b4e8e74aa7c1 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 23 Jan 2023 14:50:20 +0300
Subject: third time's the charm

---
 modules/extras.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/extras.py b/modules/extras.py
index f04ddfc2..36123aa5 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -7,7 +7,7 @@ import torch
 import tqdm
 
 from modules import shared, images, sd_models, sd_vae
-from modules.ui import plaintext_to_html
+from modules.ui_common import plaintext_to_html
 import gradio as gr
 import safetensors.torch
 
-- cgit v1.2.3

From 194cbd065e4644e986889b78a5a949e075b610e8 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 23 Jan 2023 15:50:32 +0300
Subject: fix open directory button failing

---
 modules/ui.py        | 1 -
 modules/ui_common.py | 2 +-
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/modules/ui.py b/modules/ui.py
index 94d4a80a..85ae62c7 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -5,7 +5,6 @@ import mimetypes
 import os
 import platform
 import random
-import subprocess as sp
 import sys
 import tempfile
 import time
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 8ce75b8c..9405ac1f 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -5,7 +5,7 @@ import platform
 import sys
 
 import gradio as gr
-import scipy as sp
+import subprocess as sp
 
 from modules import call_queue, shared
 from modules.generation_parameters_copypaste import image_from_url_text
-- cgit v1.2.3

From 59146621e256269b85feb536edeb745da20daf68 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 23 Jan 2023 16:40:20 +0300
Subject: better support for xformers flash attention on older versions of torch

---
 modules/errors.py                  | 12 +++++++++++
 modules/sd_hijack_optimizations.py | 42 ++++++++++++++++----------------------
 2 files changed, 30 insertions(+), 24 deletions(-)

diff --git a/modules/errors.py b/modules/errors.py
index a10e8708..f6b80dbb 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -24,6 +24,18 @@ See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable
 """)
 
 
+already_displayed = {}
+
+
+def display_once(e: Exception, task):
+    if task in already_displayed:
+        return
+
+    display(e, task)
+
+    already_displayed[task] = 1
+
+
 def run(code, task):
     try:
         code()
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 9967359b..74452709 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -9,7 +9,7 @@ from torch import einsum
 from ldm.util import default
 from einops import rearrange
 
-from modules import shared
+from modules import shared, errors
 from modules.hypernetworks import hypernetwork
 from .sub_quadratic_attention import efficient_dot_product_attention
 
@@ -279,6 +279,21 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
     )
 
 
+def get_xformers_flash_attention_op(q, k, v):
+    if not shared.cmd_opts.xformers_flash_attention:
+        return None
+
+    try:
+        flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
+        fw, bw = flash_attention_op
+        if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
+            return flash_attention_op
+    except Exception as e:
+        errors.display_once(e, "enabling flash attention")
+
+    return None
+
+
 def xformers_attention_forward(self, x, context=None, mask=None):
     h = self.heads
     q_in = self.to_q(x)
@@ -291,18 +306,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
     q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
     del q_in, k_in, v_in
 
-    if shared.cmd_opts.xformers_flash_attention:
-        op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
-        fw, bw = op
-        if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
-            # print('xformers_attention_forward', q.shape, k.shape, v.shape)
-            # Flash Attention is not availabe for the input arguments.
-            # Fallback to default xFormers' backend.
-            op = None
-    else:
-        op = None
-
-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=op)
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
 
     out = rearrange(out, 'b n h d -> b n (h d)', h=h)
     return self.to_out(out)
@@ -377,17 +381,7 @@ def xformers_attnblock_forward(self, x):
         q = q.contiguous()
         k = k.contiguous()
         v = v.contiguous()
-        if shared.cmd_opts.xformers_flash_attention:
-            op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
-            fw, bw = op
-            if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v)):
-                # print('xformers_attnblock_forward', q.shape, k.shape, v.shape)
-                # Flash Attention is not availabe for the input arguments.
-                # Fallback to default xFormers' backend.
-                op = None
-        else:
-            op = None
-        out = xformers.ops.memory_efficient_attention(q, k, v, op=op)
+        out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
         out = rearrange(out, 'b (h w) c -> b c h w', h=h)
         out = self.proj_out(out)
         return x + out
-- cgit v1.2.3
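The point of errors.display_once introduced above is that get_xformers_flash_attention_op now runs on every attention call, so a persistent failure would otherwise spam the console. A self-contained sketch of the dedup behaviour; the failing operation is invented, and the real display() also prints a traceback:

    already_displayed = {}

    def display(e, task):
        print(f"Error {task}: {e}")

    def display_once(e, task):
        if task in already_displayed:
            return
        display(e, task)
        already_displayed[task] = 1

    for _ in range(3):
        try:
            raise RuntimeError("flash attention op unsupported for these inputs")
        except Exception as e:
            display_once(e, "enabling flash attention")
    # the message is printed once, not three times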
From 925dd09c91e7338aef72e4ec99d67b8b57280215 Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Mon, 23 Jan 2023 09:03:17 -0500
Subject: improve interrogate

---
 modules/interrogate.py | 29 +++++++++++++++++------------
 modules/shared.py      |  1 +
 2 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/modules/interrogate.py b/modules/interrogate.py
index 19938cbb..1d1ac572 100644
--- a/modules/interrogate.py
+++ b/modules/interrogate.py
@@ -20,6 +20,7 @@ Category = namedtuple("Category", ["name", "topn", "items"])
 
 re_topn = re.compile(r"\.top(\d+)\.")
 
+category_types = ["artists", "flavors", "mediums", "movements"]
 
 def download_default_clip_interrogate_categories(content_dir):
     print("Downloading CLIP categories...")
@@ -27,12 +28,8 @@ def download_default_clip_interrogate_categories(content_dir):
     tmpdir = content_dir + "_tmp"
     try:
         os.makedirs(tmpdir)
-
-        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt"))
-        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
-        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
-        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))
-
+        for category_type in category_types:
+            torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
         os.rename(tmpdir, content_dir)
 
     except Exception as e:
@@ -51,12 +48,13 @@ class InterrogateModels:
 
     def __init__(self, content_dir):
         self.loaded_categories = None
+        self.selected_categories = []
         self.content_dir = content_dir
         self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
 
     def categories(self):
-        if self.loaded_categories is not None:
-            return self.loaded_categories
+        if self.loaded_categories is not None and self.selected_categories == shared.opts.interrogate_clip_categories:
+            return self.loaded_categories
 
         self.loaded_categories = []
 
@@ -64,14 +62,19 @@ class InterrogateModels:
             download_default_clip_interrogate_categories(self.content_dir)
 
         if os.path.exists(self.content_dir):
-            for filename in os.listdir(self.content_dir):
+            self.selected_categories = shared.opts.interrogate_clip_categories
+            for category_type in category_types:
+                if 'all' not in self.selected_categories and category_type not in self.selected_categories:
+                    continue
+                filename = os.path.join(self.content_dir, f"{category_type}.txt")
+                if not os.path.isfile(filename):
+                    continue
                 m = re_topn.search(filename)
                 topn = 1 if m is None else int(m.group(1))
-
-                with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
+                with open(filename, "r", encoding="utf8") as file:
                     lines = [x.strip() for x in file.readlines()]
 
-                self.loaded_categories.append(Category(name=filename, topn=topn, items=lines))
+                self.loaded_categories.append(Category(name=category_type, topn=topn, items=lines))
 
         return self.loaded_categories
 
@@ -139,6 +142,8 @@ class InterrogateModels:
     def rank(self, image_features, text_array, top_count=1):
         import clip
 
+        devices.torch_gc()
+
        if shared.opts.interrogate_clip_dict_limit != 0:
            text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
 
diff --git a/modules/shared.py b/modules/shared.py
index a644c0be..63b236c5 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -424,6 +424,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
     "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
     "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
     "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
+    "interrogate_clip_categories": OptionInfo(modules.interrogate.category_types, "CLIP: select which categories to inquire", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types}),
     "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
     "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
     "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
-- cgit v1.2.3
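A side note on the re_topn convention the interrogator relies on above: a category file can encode how many top matches to keep in its own name, e.g. flavors.top3.txt. Roughly:

    import re

    re_topn = re.compile(r"\.top(\d+)\.")

    for name in ["flavors.top3.txt", "artists.txt"]:
        m = re_topn.search(name)
        topn = 1 if m is None else int(m.group(1))
        print(name, "->", topn)  # flavors.top3.txt -> 3, artists.txt -> 1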
From 7ff1ef77dd22f7b38612f91b389237a5dbef2474 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 23 Jan 2023 17:17:31 +0300
Subject: add a message about new torch/xformers version and a way to upgrade by specifying a commandline flag

---
 launch.py |  3 ++-
 webui.py  | 26 ++++++++++++++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/launch.py b/launch.py
index 51c322c0..e7a0b50c 100644
--- a/launch.py
+++ b/launch.py
@@ -208,6 +208,7 @@ def prepare_environment():
     sys.argv, _ = extract_arg(sys.argv, '-f')
     sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
     sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
+    sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
     sys.argv, update_check = extract_arg(sys.argv, '--update-check')
     sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
     sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')
@@ -219,7 +220,7 @@ def prepare_environment():
     print(f"Python {sys.version}")
     print(f"Commit hash: {commit}")
 
-    if not is_installed("torch") or not is_installed("torchvision"):
+    if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
         run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")
 
     if not skip_torch_cuda_test:
diff --git a/webui.py b/webui.py
index 7cf5885e..bc2baeab 100644
--- a/webui.py
+++ b/webui.py
@@ -8,6 +8,7 @@ import re
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
+from packaging import version
 
 from modules import import_hook, errors, extra_networks
 from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
@@ -49,7 +50,32 @@ else:
     server_name = "0.0.0.0" if cmd_opts.listen else None
 
 
+def check_versions():
+    expected_torch_version = "1.13.1"
+
+    if version.parse(torch.__version__) < version.parse(expected_torch_version):
+        errors.print_error_explanation(f"""
+You are running torch {torch.__version__}.
+The program is tested to work with torch {expected_torch_version}.
+To reinstall the desired version, run with commandline flag --reinstall-torch.
+Beware that this will cause a lot of large files to be downloaded.
+""".strip())
+
+    expected_xformers_version = "0.0.16rc425"
+    if shared.xformers_available:
+        import xformers
+
+        if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
+            errors.print_error_explanation(f"""
+You are running xformers {xformers.__version__}.
+The program is tested to work with xformers {expected_xformers_version}.
+To reinstall the desired version, run with commandline flag --reinstall-xformers.
+""".strip())
+
+
 def initialize():
+    check_versions()
+
     extensions.list_extensions()
     localization.list_localizations(cmd_opts.localizations_dir)
 
-- cgit v1.2.3

From e8c3d03f7d9966b81458944efb25666b2143153f Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 23 Jan 2023 17:59:58 +0300
Subject: a possible fix for broken image upscaling

---
 modules/postprocessing.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/postprocessing.py b/modules/postprocessing.py
index 8514fea7..09d8e605 100644
--- a/modules/postprocessing.py
+++ b/modules/postprocessing.py
@@ -67,7 +67,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
             pp.image.info["postprocessing"] = infotext
 
         if save_output:
-            images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=pp.info, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
+            images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
 
         if extras_mode != 2 or show_extras_results:
             outputs.append(pp.image)
-- cgit v1.2.3
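The check_versions() gate above leans on packaging's PEP 440 ordering, which also handles release-candidate tags like 0.0.16rc425. A quick standalone demonstration:

    from packaging import version

    print(version.parse("1.12.1") < version.parse("1.13.1"))       # True, so the torch warning fires
    print(version.parse("0.0.16rc425") < version.parse("0.0.16"))  # True: rc builds sort before the final release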
From 6e1b296baf7a2cdc0ee747225f1704bd2d45c118 Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Mon, 23 Jan 2023 10:10:59 -0500
Subject: api-image-format

---
 modules/api/api.py | 34 ++++++++++++++++++++++++----------
 1 file changed, 24 insertions(+), 10 deletions(-)

diff --git a/modules/api/api.py b/modules/api/api.py
index 5d60fc0a..b1dd14cc 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -22,6 +22,8 @@ from modules.sd_models import checkpoints_list, find_checkpoint_config
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
 from typing import List
+import piexif
+import piexif.helper
 
 def upscaler_to_index(name: str):
     try:
@@ -56,18 +58,30 @@ def decode_base64_to_image(encoding):
 
 def encode_pil_to_base64(image):
     with io.BytesIO() as output_bytes:
-        # Copy any text-only metadata
-        use_metadata = False
-        metadata = PngImagePlugin.PngInfo()
-        for key, value in image.info.items():
-            if isinstance(key, str) and isinstance(value, str):
-                metadata.add_text(key, value)
-                use_metadata = True
+        if opts.samples_format.lower() == 'png':
+            use_metadata = False
+            metadata = PngImagePlugin.PngInfo()
+            for key, value in image.info.items():
+                if isinstance(key, str) and isinstance(value, str):
+                    metadata.add_text(key, value)
+                    use_metadata = True
+            image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality)
+
+        elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"):
+            parameters = image.info.get('parameters', None)
+            exif_bytes = piexif.dump({
+                "Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") }
+            })
+            if opts.samples_format.lower() in ("jpg", "jpeg"):
+                image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality)
+            else:
+                image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality)
+
+        else:
+            raise HTTPException(status_code=500, detail="Invalid image format")
 
-        image.save(
-            output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
-        )
         bytes_data = output_bytes.getvalue()
+
     return base64.b64encode(bytes_data)
 
 def api_middleware(app: FastAPI):
-- cgit v1.2.3
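To get the parameters back out of a JPEG or WEBP produced by this API path, the same piexif helpers work in reverse. A sketch, with a hypothetical file name decoded from the API's base64 payload:

    import piexif
    import piexif.helper

    exif = piexif.load("output.jpg")  # hypothetical file, not part of the patch
    raw = exif["Exif"].get(piexif.ExifIFD.UserComment)
    if raw:
        print(piexif.helper.UserComment.load(raw))  # the stored generation parameters string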
- background-color: #fff; } .extra-network-thumbs .actions .name { -- cgit v1.2.3 From c6f20f72629f3c417f10db2289d131441c6832f5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 18:52:55 +0300 Subject: make loras before 0.4.0 ALSO work --- extensions-builtin/Lora/lora.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 220e64ff..137e58f7 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -57,6 +57,7 @@ class LoraUpDownModule: def __init__(self): self.up = None self.down = None + self.alpha = None def assign_lora_names_to_compvis_modules(sd_model): @@ -165,7 +166,7 @@ def lora_forward(module, input, res): for lora in loaded_loras: module = lora.modules.get(lora_layer_name, None) if module is not None: - res = res + module.up(module.down(input)) * lora.multiplier * module.alpha / module.up.weight.shape[1] + res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) return res -- cgit v1.2.3 From 04a561c11c9bf9a00d7f9b50ca3f7962aa59ba6e Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 23 Jan 2023 12:29:23 -0500 Subject: add option to skip interrogate categories --- modules/interrogate.py | 32 ++++++++++++++++++-------------- modules/shared.py | 2 +- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/modules/interrogate.py b/modules/interrogate.py index 1d1ac572..c252b148 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -2,6 +2,7 @@ import os import sys import traceback from collections import namedtuple +from pathlib import Path import re import torch @@ -20,12 +21,16 @@ Category = namedtuple("Category", ["name", "topn", "items"]) re_topn = re.compile(r"\.top(\d+)\.") -category_types = ["artists", "flavors", "mediums", "movements"] +def category_types(): + return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')] + def download_default_clip_interrogate_categories(content_dir): print("Downloading CLIP categories...") tmpdir = content_dir + "_tmp" + category_types = ["artists", "flavors", "mediums", "movements"] + try: os.makedirs(tmpdir) for category_type in category_types: @@ -48,33 +53,32 @@ class InterrogateModels: def __init__(self, content_dir): self.loaded_categories = None - self.selected_categories = [] + self.skip_categories = [] self.content_dir = content_dir self.running_on_cpu = devices.device_interrogate == torch.device("cpu") def categories(self): - if self.loaded_categories is not None and self.selected_categories == shared.opts.interrogate_clip_categories: + if not os.path.exists(self.content_dir): + download_default_clip_interrogate_categories(self.content_dir) + + if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories: return self.loaded_categories self.loaded_categories = [] - if not os.path.exists(self.content_dir): - download_default_clip_interrogate_categories(self.content_dir) - if os.path.exists(self.content_dir): - self.selected_categories = shared.opts.interrogate_clip_categories - for category_type in category_types: - if 'all' not in self.selected_categories and category_type not in self.selected_categories: - continue - filename = os.path.join(self.content_dir, f"{category_type}.txt") - if not os.path.isfile(filename): + self.skip_categories = shared.opts.interrogate_clip_skip_categories + category_types = [] + for 
filename in Path(self.content_dir).glob('*.txt'): + category_types.append(filename.stem) + if filename.stem in self.skip_categories: continue - m = re_topn.search(filename) + m = re_topn.search(filename.stem) topn = 1 if m is None else int(m.group(1)) with open(filename, "r", encoding="utf8") as file: lines = [x.strip() for x in file.readlines()] - self.loaded_categories.append(Category(name=category_type, topn=topn, items=lines)) + self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines)) return self.loaded_categories diff --git a/modules/shared.py b/modules/shared.py index d7a18f6a..5f713bee 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -424,7 +424,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"), - "interrogate_clip_categories": OptionInfo(modules.interrogate.category_types, "CLIP: select which categories to inquire", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types}), + "interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types), "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), -- cgit v1.2.3 From 865af20d8a4a823df3c950f5c9c9092a541bc57a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 21:28:59 +0300 Subject: suppress A matching Triton is not available message you can all now stop worrying about it --- webui.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/webui.py b/webui.py index bc2baeab..e1565a8d 100644 --- a/webui.py +++ b/webui.py @@ -1,6 +1,5 @@ import os import sys -import threading import time import importlib import signal @@ -10,6 +9,9 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware from packaging import version +import logging +logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) + from modules import import_hook, errors, extra_networks from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call -- cgit v1.2.3 From 7b1c7ba87b14da9960d0347269421233f4cb5838 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 23:11:34 +0300 Subject: add support for apostrophe in extra network names --- html/extra-networks-card.html | 4 ++-- modules/ui_extra_networks.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index 1bdf1d27..aa9fca87 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,8 +1,8 @@ -
      +
      {name} diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 2ddac3d8..8b4f97f8 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -3,6 +3,7 @@ import os.path from modules import shared import gradio as gr import json +import html from modules.generation_parameters_copypaste import image_from_url_text @@ -54,12 +55,13 @@ class ExtraNetworksPage: preview = item.get("preview", None) args = { - "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '', + "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '', "prompt": item["prompt"], "tabname": json.dumps(tabname), "local_preview": json.dumps(item["local_preview"]), "name": item["name"], - "allow_negative_prompt": "true" if self.allow_negative_prompt else "false", + "card_clicked": '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"', + "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"', } return self.card_page.format(**args) -- cgit v1.2.3 From 5c1cb9263f980641007088a37360fcab01761d37 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 00:24:17 +0300 Subject: fix BLIP failing to import depending on configuration --- modules/interrogate.py | 3 ++- modules/paths.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/modules/interrogate.py b/modules/interrogate.py index c252b148..236e6983 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -83,7 +83,8 @@ class InterrogateModels: return self.loaded_categories def load_blip_model(self): - import models.blip + with paths.Prioritize("BLIP"): + import models.blip files = modelloader.load_models( model_path=os.path.join(paths.models_path, "BLIP"), diff --git a/modules/paths.py b/modules/paths.py index 4dd03a35..20b3e4d8 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -38,3 +38,17 @@ for d, must_exist, what, options in path_dirs: else: sys.path.append(d) paths[what] = d + + +class Prioritize: + def __init__(self, name): + self.name = name + self.path = None + + def __enter__(self): + self.path = sys.path.copy() + sys.path = [paths[self.name]] + sys.path + + def __exit__(self, exc_type, exc_val, exc_tb): + sys.path = self.path + self.path = None -- cgit v1.2.3 From 82a28bfe35928e244d4b51d72ce424aff5619b75 Mon Sep 17 00:00:00 2001 From: Mykeehu Date: Mon, 23 Jan 2023 22:36:27 +0100 Subject: Fix extra network thumbs label color Added white color for labels. 
---
 style.css | 1 +
 1 file changed, 1 insertion(+)

diff --git a/style.css b/style.css
index b2677fa1..ec046f78 100644
--- a/style.css
+++ b/style.css
@@ -857,6 +857,7 @@ footer {
     white-space: nowrap;
     text-overflow: ellipsis;
     background: rgba(0,0,0,.5);
+    color: white;
 }
 
 .extra-network-thumbs .card:hover .actions .name {
-- 
cgit v1.2.3


From 45e270dfc853216b2c413f915946f0f2842e57a4 Mon Sep 17 00:00:00 2001
From: Vladimir Mandic
Date: Mon, 23 Jan 2023 17:11:22 -0500
Subject: add image decode exception handling

---
 modules/api/api.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/modules/api/api.py b/modules/api/api.py
index b1dd14cc..e6e31e41 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -53,7 +53,11 @@ def setUpscalers(req: dict):
 def decode_base64_to_image(encoding):
     if encoding.startswith("data:image/"):
         encoding = encoding.split(";")[1].split(",")[1]
-    return Image.open(BytesIO(base64.b64decode(encoding)))
+    try:
+        image = Image.open(BytesIO(base64.b64decode(encoding)))
+        return image
+    except Exception as err:
+        raise HTTPException(status_code=500, detail="Invalid encoded image")
 
 def encode_pil_to_base64(image):
     with io.BytesIO() as output_bytes:
-- 
cgit v1.2.3


From f99352582084890b9167c1bf8699865bea0cef5f Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Mon, 23 Jan 2023 21:50:59 -0500
Subject: Make SwinIR interruptible

---
 extensions-builtin/SwinIR/scripts/swinir_model.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 9a74b253..3479760a 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -8,7 +8,7 @@ from basicsr.utils.download_util import load_file_from_url
 from tqdm import tqdm
 
 from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import cmd_opts, opts
+from modules.shared import cmd_opts, opts, state
 from swinir_model_arch import SwinIR as net
 from swinir_model_arch_v2 import Swin2SR as net2
 from modules.upscaler import Upscaler, UpscalerData
@@ -145,7 +145,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
 
     with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
         for h_idx in h_idx_list:
+            if state.interrupted:
+                break
+
             for w_idx in w_idx_list:
+                if state.interrupted:
+                    break
+
                 in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
                 out_patch = model(in_patch)
                 out_patch_mask = torch.ones_like(out_patch)
-- 
cgit v1.2.3


From 3c47b050367ee220dcfed7be7883878417735614 Mon Sep 17 00:00:00 2001
From: catboxanon <122327233+catboxanon@users.noreply.github.com>
Date: Mon, 23 Jan 2023 22:00:27 -0500
Subject: Also make SwinIR skippable

---
 extensions-builtin/SwinIR/scripts/swinir_model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 3479760a..e8783bca 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -145,11 +145,11 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
 
     with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
         for h_idx in h_idx_list:
-            if state.interrupted:
+            if state.interrupted or state.skipped:
                 break
 
             for w_idx in w_idx_list:
-                if state.interrupted:
+                if state.interrupted or state.skipped:
                     break
 
                 in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
-- 
cgit v1.2.3


From 078e16e4d33ccbd40ff3ecfbb57ffd33a2a16c47 Mon Sep 17 00:00:00 2001
From: acncagua
Date: Tue, 24 Jan 2023 12:21:07 +0900
Subject: Set Linux xformers 0.0.16RC425

---
 launch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/launch.py b/launch.py
index e7a0b50c..82094fa0 100644
--- a/launch.py
+++ b/launch.py
@@ -245,7 +245,7 @@ def prepare_environment():
             if not is_installed("xformers"):
                 exit(0)
         elif platform.system() == "Linux":
-            run_pip("install xformers", "xformers")
+            run_pip("install xformers==0.0.16rc425", "xformers")
 
     if not is_installed("pyngrok") and ngrok:
         run_pip("install pyngrok", "ngrok")
-- 
cgit v1.2.3


From f64af77adcd20fabe00e1e642512db9c6742ed23 Mon Sep 17 00:00:00 2001
From: brkirch
Date: Mon, 23 Jan 2023 22:49:20 -0500
Subject: Fix different first gen with Approx NN previews

The loading of the model for approx nn live previews can change the
internal state of PyTorch, resulting in a different image. This can be
avoided by preloading the approx nn model in advance.
---
 modules/processing.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/modules/processing.py b/modules/processing.py
index bc541e2f..3bd590ba 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -13,7 +13,7 @@ from skimage import exposure
 from typing import Any, Dict, List, Optional
 
 import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx
 from modules.sd_hijack import model_hijack
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
@@ -568,6 +568,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
         with devices.autocast():
             p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
 
+        if shared.opts.live_previews_enable and sd_samplers.approximation_indexes.get(shared.opts.show_progress_type, 0) == 1:
+            # preload approx nn model before sampling for a more deterministic result
+            sd_vae_approx.model()
+
         if not p.disable_extra_networks:
             extra_networks.activate(p, extra_network_data)
 
-- 
cgit v1.2.3


From 42a70d74771e8920f658e741679768ed145dd76a Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Tue, 24 Jan 2023 10:05:45 +0300
Subject: repair sdapi/v1/upscalers returning bogus results

---
 modules/api/api.py    | 16 +++++++++-------
 modules/api/models.py |  2 +-
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/modules/api/api.py b/modules/api/api.py
index e6e31e41..da2a5daf 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -375,13 +375,15 @@ class Api:
         return [{"name": sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in sd_samplers.all_samplers]
 
     def get_upscalers(self):
-        upscalers = []
-
-        for upscaler in shared.sd_upscalers:
-            u = upscaler.scaler
-            upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
-
-        return upscalers
+        return [
+            {
+                "name": upscaler.name,
+                "model_name": upscaler.scaler.model_name,
+                "model_path": upscaler.data_path,
+                "scale": upscaler.scale,
+            }
+            for upscaler in shared.sd_upscalers
+        ]
 
     def get_sd_models(self):
         return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256":
x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] diff --git a/modules/api/models.py b/modules/api/models.py index 1eb1fcf1..e562ab54 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -219,7 +219,7 @@ class UpscalerItem(BaseModel): name: str = Field(title="Name") model_name: Optional[str] = Field(title="Model Name") model_path: Optional[str] = Field(title="Path") - model_url: Optional[str] = Field(title="URL") + scale: Optional[float] = Field(title="Scale") class SDModelItem(BaseModel): title: str = Field(title="Title") -- cgit v1.2.3 From 602a1864b05075ca4283986e6f5c7d5bce864e11 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 10:09:30 +0300 Subject: also return the removed field to sdapi/v1/upscalers because someone might have relied on it existing --- modules/api/api.py | 1 + modules/api/models.py | 1 + 2 files changed, 2 insertions(+) diff --git a/modules/api/api.py b/modules/api/api.py index da2a5daf..25c65e57 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -380,6 +380,7 @@ class Api: "name": upscaler.name, "model_name": upscaler.scaler.model_name, "model_path": upscaler.data_path, + "model_url": None, "scale": upscaler.scale, } for upscaler in shared.sd_upscalers diff --git a/modules/api/models.py b/modules/api/models.py index e562ab54..805bd8f7 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -219,6 +219,7 @@ class UpscalerItem(BaseModel): name: str = Field(title="Name") model_name: Optional[str] = Field(title="Model Name") model_path: Optional[str] = Field(title="Path") + model_url: Optional[str] = Field(title="URL") scale: Optional[float] = Field(title="Scale") class SDModelItem(BaseModel): -- cgit v1.2.3 From d30ac02f28bf5fa1ca5d4ba444180ba9e50b4860 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Tue, 24 Jan 2023 02:21:32 -0500 Subject: renamed xy to xyz grid this is mostly just so git can detect it properly --- scripts/xy_grid.py | 498 ---------------------------------------------------- scripts/xyz_grid.py | 498 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 498 insertions(+), 498 deletions(-) delete mode 100644 scripts/xy_grid.py create mode 100644 scripts/xyz_grid.py diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py deleted file mode 100644 index 1a452355..00000000 --- a/scripts/xy_grid.py +++ /dev/null @@ -1,498 +0,0 @@ -from collections import namedtuple -from copy import copy -from itertools import permutations, chain -import random -import csv -from io import StringIO -from PIL import Image -import numpy as np - -import modules.scripts as scripts -import gradio as gr - -from modules import images, paths, sd_samplers, processing, sd_models, sd_vae -from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img -from modules.shared import opts, cmd_opts, state -import modules.shared as shared -import modules.sd_samplers -import modules.sd_models -import modules.sd_vae -import glob -import os -import re - -from modules.ui_components import ToolButton - -fill_values_symbol = "\U0001f4d2" # 📒 - - -def apply_field(field): - def fun(p, x, xs): - setattr(p, field, x) - - return fun - - -def apply_prompt(p, x, xs): - if xs[0] not in p.prompt and xs[0] not in p.negative_prompt: - raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.") - - p.prompt = p.prompt.replace(xs[0], x) - p.negative_prompt = p.negative_prompt.replace(xs[0], x) - - -def apply_order(p, x, xs): - 
token_order = [] - - # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen - for token in x: - token_order.append((p.prompt.find(token), token)) - - token_order.sort(key=lambda t: t[0]) - - prompt_parts = [] - - # Split the prompt up, taking out the tokens - for _, token in token_order: - n = p.prompt.find(token) - prompt_parts.append(p.prompt[0:n]) - p.prompt = p.prompt[n + len(token):] - - # Rebuild the prompt with the tokens in the order we want - prompt_tmp = "" - for idx, part in enumerate(prompt_parts): - prompt_tmp += part - prompt_tmp += x[idx] - p.prompt = prompt_tmp + p.prompt - - -def apply_sampler(p, x, xs): - sampler_name = sd_samplers.samplers_map.get(x.lower(), None) - if sampler_name is None: - raise RuntimeError(f"Unknown sampler: {x}") - - p.sampler_name = sampler_name - - -def confirm_samplers(p, xs): - for x in xs: - if x.lower() not in sd_samplers.samplers_map: - raise RuntimeError(f"Unknown sampler: {x}") - - -def apply_checkpoint(p, x, xs): - info = modules.sd_models.get_closet_checkpoint_match(x) - if info is None: - raise RuntimeError(f"Unknown checkpoint: {x}") - modules.sd_models.reload_model_weights(shared.sd_model, info) - - -def confirm_checkpoints(p, xs): - for x in xs: - if modules.sd_models.get_closet_checkpoint_match(x) is None: - raise RuntimeError(f"Unknown checkpoint: {x}") - - -def apply_clip_skip(p, x, xs): - opts.data["CLIP_stop_at_last_layers"] = x - - -def apply_upscale_latent_space(p, x, xs): - if x.lower().strip() != '0': - opts.data["use_scale_latent_for_hires_fix"] = True - else: - opts.data["use_scale_latent_for_hires_fix"] = False - - -def find_vae(name: str): - if name.lower() in ['auto', 'automatic']: - return modules.sd_vae.unspecified - if name.lower() == 'none': - return None - else: - choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()] - if len(choices) == 0: - print(f"No VAE found for {name}; using automatic") - return modules.sd_vae.unspecified - else: - return modules.sd_vae.vae_dict[choices[0]] - - -def apply_vae(p, x, xs): - modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x)) - - -def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _): - p.styles = x.split(',') - - -def format_value_add_label(p, opt, x): - if type(x) == float: - x = round(x, 8) - - return f"{opt.label}: {x}" - - -def format_value(p, opt, x): - if type(x) == float: - x = round(x, 8) - return x - - -def format_value_join_list(p, opt, x): - return ", ".join(x) - - -def do_nothing(p, x, xs): - pass - - -def format_nothing(p, opt, x): - return "" - - -def str_permutations(x): - """dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" - return x - - -class AxisOption: - def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None): - self.label = label - self.type = type - self.apply = apply - self.format_value = format_value - self.confirm = confirm - self.cost = cost - self.choices = choices - - -class AxisOptionImg2Img(AxisOption): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.is_img2img = True - -class AxisOptionTxt2Img(AxisOption): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.is_img2img = False - - -axis_options = [ - AxisOption("Nothing", str, do_nothing, format_value=format_nothing), - AxisOption("Seed", int, apply_field("seed")), - AxisOption("Var. 
seed", int, apply_field("subseed")), - AxisOption("Var. strength", float, apply_field("subseed_strength")), - AxisOption("Steps", int, apply_field("steps")), - AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")), - AxisOption("CFG Scale", float, apply_field("cfg_scale")), - AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value), - AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), - AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), - AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]), - AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), - AxisOption("Sigma Churn", float, apply_field("s_churn")), - AxisOption("Sigma min", float, apply_field("s_tmin")), - AxisOption("Sigma max", float, apply_field("s_tmax")), - AxisOption("Sigma noise", float, apply_field("s_noise")), - AxisOption("Eta", float, apply_field("eta")), - AxisOption("Clip skip", int, apply_clip_skip), - AxisOption("Denoising", float, apply_field("denoising_strength")), - AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]), - AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")), - AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)), - AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), -] - - -def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order): - ver_texts = [[images.GridAnnotation(y)] for y in y_labels] - hor_texts = [[images.GridAnnotation(x)] for x in x_labels] - - # Temporary list of all the images that are generated to be populated into the grid. 
- # Will be filled with empty images for any individual step that fails to process properly - image_cache = [None] * (len(xs) * len(ys)) - - processed_result = None - cell_mode = "P" - cell_size = (1, 1) - - state.job_count = len(xs) * len(ys) * p.n_iter - - def process_cell(x, y, ix, iy): - nonlocal image_cache, processed_result, cell_mode, cell_size - - state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" - - processed: Processed = cell(x, y) - - try: - # this dereference will throw an exception if the image was not processed - # (this happens in cases such as if the user stops the process from the UI) - processed_image = processed.images[0] - - if processed_result is None: - # Use our first valid processed result as a template container to hold our full results - processed_result = copy(processed) - cell_mode = processed_image.mode - cell_size = processed_image.size - processed_result.images = [Image.new(cell_mode, cell_size)] - - image_cache[ix + iy * len(xs)] = processed_image - if include_lone_images: - processed_result.images.append(processed_image) - processed_result.all_prompts.append(processed.prompt) - processed_result.all_seeds.append(processed.seed) - processed_result.infotexts.append(processed.infotexts[0]) - except: - image_cache[ix + iy * len(xs)] = Image.new(cell_mode, cell_size) - - if swap_axes_processing_order: - for ix, x in enumerate(xs): - for iy, y in enumerate(ys): - process_cell(x, y, ix, iy) - else: - for iy, y in enumerate(ys): - for ix, x in enumerate(xs): - process_cell(x, y, ix, iy) - - if not processed_result: - print("Unexpected error: draw_xy_grid failed to return even a single processed image") - return Processed(p, []) - - grid = images.image_grid(image_cache, rows=len(ys)) - if draw_legend: - grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts) - - processed_result.images[0] = grid - - return processed_result - - -class SharedSettingsStackHelper(object): - def __enter__(self): - self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers - self.vae = opts.sd_vae - - def __exit__(self, exc_type, exc_value, tb): - opts.data["sd_vae"] = self.vae - modules.sd_models.reload_model_weights() - modules.sd_vae.reload_vae_weights() - - opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers - - -re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") -re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*") - -re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") -re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") - - -class Script(scripts.Script): - def title(self): - return "X/Y plot" - - def ui(self, is_img2img): - self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img] - - with gr.Row(): - with gr.Column(scale=19): - with gr.Row(): - x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) - x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) - fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_x_tool_button", visible=False) - - with gr.Row(): - y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], 
value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) - y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) - fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False) - - with gr.Row(variant="compact", elem_id="axis_options"): - draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) - include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) - no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) - swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button") - - def swap_axes(x_type, x_values, y_type, y_values): - return self.current_axis_options[y_type].label, y_values, self.current_axis_options[x_type].label, x_values - - swap_args = [x_type, x_values, y_type, y_values] - swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args) - - def fill(x_type): - axis = self.current_axis_options[x_type] - return ", ".join(axis.choices()) if axis.choices else gr.update() - - fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values]) - fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values]) - - def select_axis(x_type): - return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None) - - x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button]) - y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) - - return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] - - def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds): - if not no_fixed_seeds: - modules.processing.fix_seed(p) - - if not opts.return_grid: - p.batch_size = 1 - - def process_axis(opt, vals): - if opt.label == 'Nothing': - return [0] - - valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))] - - if opt.type == int: - valslist_ext = [] - - for val in valslist: - m = re_range.fullmatch(val) - mc = re_range_count.fullmatch(val) - if m is not None: - start = int(m.group(1)) - end = int(m.group(2))+1 - step = int(m.group(3)) if m.group(3) is not None else 1 - - valslist_ext += list(range(start, end, step)) - elif mc is not None: - start = int(mc.group(1)) - end = int(mc.group(2)) - num = int(mc.group(3)) if mc.group(3) is not None else 1 - - valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()] - else: - valslist_ext.append(val) - - valslist = valslist_ext - elif opt.type == float: - valslist_ext = [] - - for val in valslist: - m = re_range_float.fullmatch(val) - mc = re_range_count_float.fullmatch(val) - if m is not None: - start = float(m.group(1)) - end = float(m.group(2)) - step = float(m.group(3)) if m.group(3) is not None else 1 - - valslist_ext += np.arange(start, end + step, step).tolist() - elif mc is not None: - start = float(mc.group(1)) - end = float(mc.group(2)) - num = int(mc.group(3)) if mc.group(3) is not None else 1 - - valslist_ext += np.linspace(start=start, stop=end, num=num).tolist() - else: - valslist_ext.append(val) - - valslist = valslist_ext - elif opt.type == str_permutations: - valslist = list(permutations(valslist)) - - valslist = [opt.type(x) for x in valslist] - - # Confirm options are valid before starting - if opt.confirm: - opt.confirm(p, valslist) - - return valslist - - x_opt = 
self.current_axis_options[x_type] - xs = process_axis(x_opt, x_values) - - y_opt = self.current_axis_options[y_type] - ys = process_axis(y_opt, y_values) - - def fix_axis_seeds(axis_opt, axis_list): - if axis_opt.label in ['Seed', 'Var. seed']: - return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] - else: - return axis_list - - if not no_fixed_seeds: - xs = fix_axis_seeds(x_opt, xs) - ys = fix_axis_seeds(y_opt, ys) - - if x_opt.label == 'Steps': - total_steps = sum(xs) * len(ys) - elif y_opt.label == 'Steps': - total_steps = sum(ys) * len(xs) - else: - total_steps = p.steps * len(xs) * len(ys) - - if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr: - if x_opt.label == "Hires steps": - total_steps += sum(xs) * len(ys) - elif y_opt.label == "Hires steps": - total_steps += sum(ys) * len(xs) - elif p.hr_second_pass_steps: - total_steps += p.hr_second_pass_steps * len(xs) * len(ys) - else: - total_steps *= 2 - - total_steps *= p.n_iter - - image_cell_count = p.n_iter * p.batch_size - cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else "" - print(f"X/Y plot will create {len(xs) * len(ys) * image_cell_count} images on a {len(xs)}x{len(ys)} grid{cell_console_text}. (Total steps to process: {total_steps})") - shared.total_tqdm.updateTotal(total_steps) - - grid_infotext = [None] - - # If one of the axes is very slow to change between (like SD model - # checkpoint), then make sure it is in the outer iteration of the nested - # `for` loop. - swap_axes_processing_order = x_opt.cost > y_opt.cost - - def cell(x, y): - if shared.state.interrupted: - return Processed(p, [], p.seed, "") - - pc = copy(p) - x_opt.apply(pc, x, xs) - y_opt.apply(pc, y, ys) - - res = process_images(pc) - - if grid_infotext[0] is None: - pc.extra_generation_params = copy(pc.extra_generation_params) - - if x_opt.label != 'Nothing': - pc.extra_generation_params["X Type"] = x_opt.label - pc.extra_generation_params["X Values"] = x_values - if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds: - pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs]) - - if y_opt.label != 'Nothing': - pc.extra_generation_params["Y Type"] = y_opt.label - pc.extra_generation_params["Y Values"] = y_values - if y_opt.label in ["Seed", "Var. 
seed"] and not no_fixed_seeds: - pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys]) - - grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds) - - return res - - with SharedSettingsStackHelper(): - processed = draw_xy_grid( - p, - xs=xs, - ys=ys, - x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], - y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], - cell=cell, - draw_legend=draw_legend, - include_lone_images=include_lone_images, - swap_axes_processing_order=swap_axes_processing_order - ) - - if opts.grid_save: - images.save_image(processed.images[0], p.outpath_grids, "xy_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) - - return processed diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py new file mode 100644 index 00000000..1a452355 --- /dev/null +++ b/scripts/xyz_grid.py @@ -0,0 +1,498 @@ +from collections import namedtuple +from copy import copy +from itertools import permutations, chain +import random +import csv +from io import StringIO +from PIL import Image +import numpy as np + +import modules.scripts as scripts +import gradio as gr + +from modules import images, paths, sd_samplers, processing, sd_models, sd_vae +from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img +from modules.shared import opts, cmd_opts, state +import modules.shared as shared +import modules.sd_samplers +import modules.sd_models +import modules.sd_vae +import glob +import os +import re + +from modules.ui_components import ToolButton + +fill_values_symbol = "\U0001f4d2" # 📒 + + +def apply_field(field): + def fun(p, x, xs): + setattr(p, field, x) + + return fun + + +def apply_prompt(p, x, xs): + if xs[0] not in p.prompt and xs[0] not in p.negative_prompt: + raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.") + + p.prompt = p.prompt.replace(xs[0], x) + p.negative_prompt = p.negative_prompt.replace(xs[0], x) + + +def apply_order(p, x, xs): + token_order = [] + + # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen + for token in x: + token_order.append((p.prompt.find(token), token)) + + token_order.sort(key=lambda t: t[0]) + + prompt_parts = [] + + # Split the prompt up, taking out the tokens + for _, token in token_order: + n = p.prompt.find(token) + prompt_parts.append(p.prompt[0:n]) + p.prompt = p.prompt[n + len(token):] + + # Rebuild the prompt with the tokens in the order we want + prompt_tmp = "" + for idx, part in enumerate(prompt_parts): + prompt_tmp += part + prompt_tmp += x[idx] + p.prompt = prompt_tmp + p.prompt + + +def apply_sampler(p, x, xs): + sampler_name = sd_samplers.samplers_map.get(x.lower(), None) + if sampler_name is None: + raise RuntimeError(f"Unknown sampler: {x}") + + p.sampler_name = sampler_name + + +def confirm_samplers(p, xs): + for x in xs: + if x.lower() not in sd_samplers.samplers_map: + raise RuntimeError(f"Unknown sampler: {x}") + + +def apply_checkpoint(p, x, xs): + info = modules.sd_models.get_closet_checkpoint_match(x) + if info is None: + raise RuntimeError(f"Unknown checkpoint: {x}") + modules.sd_models.reload_model_weights(shared.sd_model, info) + + +def confirm_checkpoints(p, xs): + for x in xs: + if modules.sd_models.get_closet_checkpoint_match(x) is None: + raise RuntimeError(f"Unknown checkpoint: {x}") + + +def apply_clip_skip(p, x, xs): + opts.data["CLIP_stop_at_last_layers"] = x + + +def 
apply_upscale_latent_space(p, x, xs): + if x.lower().strip() != '0': + opts.data["use_scale_latent_for_hires_fix"] = True + else: + opts.data["use_scale_latent_for_hires_fix"] = False + + +def find_vae(name: str): + if name.lower() in ['auto', 'automatic']: + return modules.sd_vae.unspecified + if name.lower() == 'none': + return None + else: + choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()] + if len(choices) == 0: + print(f"No VAE found for {name}; using automatic") + return modules.sd_vae.unspecified + else: + return modules.sd_vae.vae_dict[choices[0]] + + +def apply_vae(p, x, xs): + modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x)) + + +def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _): + p.styles = x.split(',') + + +def format_value_add_label(p, opt, x): + if type(x) == float: + x = round(x, 8) + + return f"{opt.label}: {x}" + + +def format_value(p, opt, x): + if type(x) == float: + x = round(x, 8) + return x + + +def format_value_join_list(p, opt, x): + return ", ".join(x) + + +def do_nothing(p, x, xs): + pass + + +def format_nothing(p, opt, x): + return "" + + +def str_permutations(x): + """dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" + return x + + +class AxisOption: + def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None): + self.label = label + self.type = type + self.apply = apply + self.format_value = format_value + self.confirm = confirm + self.cost = cost + self.choices = choices + + +class AxisOptionImg2Img(AxisOption): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_img2img = True + +class AxisOptionTxt2Img(AxisOption): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_img2img = False + + +axis_options = [ + AxisOption("Nothing", str, do_nothing, format_value=format_nothing), + AxisOption("Seed", int, apply_field("seed")), + AxisOption("Var. seed", int, apply_field("subseed")), + AxisOption("Var. 
strength", float, apply_field("subseed_strength")), + AxisOption("Steps", int, apply_field("steps")), + AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")), + AxisOption("CFG Scale", float, apply_field("cfg_scale")), + AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value), + AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), + AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), + AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]), + AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), + AxisOption("Sigma Churn", float, apply_field("s_churn")), + AxisOption("Sigma min", float, apply_field("s_tmin")), + AxisOption("Sigma max", float, apply_field("s_tmax")), + AxisOption("Sigma noise", float, apply_field("s_noise")), + AxisOption("Eta", float, apply_field("eta")), + AxisOption("Clip skip", int, apply_clip_skip), + AxisOption("Denoising", float, apply_field("denoising_strength")), + AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]), + AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")), + AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)), + AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), +] + + +def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order): + ver_texts = [[images.GridAnnotation(y)] for y in y_labels] + hor_texts = [[images.GridAnnotation(x)] for x in x_labels] + + # Temporary list of all the images that are generated to be populated into the grid. 
+ # Will be filled with empty images for any individual step that fails to process properly + image_cache = [None] * (len(xs) * len(ys)) + + processed_result = None + cell_mode = "P" + cell_size = (1, 1) + + state.job_count = len(xs) * len(ys) * p.n_iter + + def process_cell(x, y, ix, iy): + nonlocal image_cache, processed_result, cell_mode, cell_size + + state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" + + processed: Processed = cell(x, y) + + try: + # this dereference will throw an exception if the image was not processed + # (this happens in cases such as if the user stops the process from the UI) + processed_image = processed.images[0] + + if processed_result is None: + # Use our first valid processed result as a template container to hold our full results + processed_result = copy(processed) + cell_mode = processed_image.mode + cell_size = processed_image.size + processed_result.images = [Image.new(cell_mode, cell_size)] + + image_cache[ix + iy * len(xs)] = processed_image + if include_lone_images: + processed_result.images.append(processed_image) + processed_result.all_prompts.append(processed.prompt) + processed_result.all_seeds.append(processed.seed) + processed_result.infotexts.append(processed.infotexts[0]) + except: + image_cache[ix + iy * len(xs)] = Image.new(cell_mode, cell_size) + + if swap_axes_processing_order: + for ix, x in enumerate(xs): + for iy, y in enumerate(ys): + process_cell(x, y, ix, iy) + else: + for iy, y in enumerate(ys): + for ix, x in enumerate(xs): + process_cell(x, y, ix, iy) + + if not processed_result: + print("Unexpected error: draw_xy_grid failed to return even a single processed image") + return Processed(p, []) + + grid = images.image_grid(image_cache, rows=len(ys)) + if draw_legend: + grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts) + + processed_result.images[0] = grid + + return processed_result + + +class SharedSettingsStackHelper(object): + def __enter__(self): + self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers + self.vae = opts.sd_vae + + def __exit__(self, exc_type, exc_value, tb): + opts.data["sd_vae"] = self.vae + modules.sd_models.reload_model_weights() + modules.sd_vae.reload_vae_weights() + + opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers + + +re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") +re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*") + +re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") +re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") + + +class Script(scripts.Script): + def title(self): + return "X/Y plot" + + def ui(self, is_img2img): + self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img] + + with gr.Row(): + with gr.Column(scale=19): + with gr.Row(): + x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) + x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) + fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_x_tool_button", visible=False) + + with gr.Row(): + y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], 
value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) + y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) + fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False) + + with gr.Row(variant="compact", elem_id="axis_options"): + draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) + include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) + no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) + swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button") + + def swap_axes(x_type, x_values, y_type, y_values): + return self.current_axis_options[y_type].label, y_values, self.current_axis_options[x_type].label, x_values + + swap_args = [x_type, x_values, y_type, y_values] + swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args) + + def fill(x_type): + axis = self.current_axis_options[x_type] + return ", ".join(axis.choices()) if axis.choices else gr.update() + + fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values]) + fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values]) + + def select_axis(x_type): + return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None) + + x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button]) + y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) + + return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] + + def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds): + if not no_fixed_seeds: + modules.processing.fix_seed(p) + + if not opts.return_grid: + p.batch_size = 1 + + def process_axis(opt, vals): + if opt.label == 'Nothing': + return [0] + + valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))] + + if opt.type == int: + valslist_ext = [] + + for val in valslist: + m = re_range.fullmatch(val) + mc = re_range_count.fullmatch(val) + if m is not None: + start = int(m.group(1)) + end = int(m.group(2))+1 + step = int(m.group(3)) if m.group(3) is not None else 1 + + valslist_ext += list(range(start, end, step)) + elif mc is not None: + start = int(mc.group(1)) + end = int(mc.group(2)) + num = int(mc.group(3)) if mc.group(3) is not None else 1 + + valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()] + else: + valslist_ext.append(val) + + valslist = valslist_ext + elif opt.type == float: + valslist_ext = [] + + for val in valslist: + m = re_range_float.fullmatch(val) + mc = re_range_count_float.fullmatch(val) + if m is not None: + start = float(m.group(1)) + end = float(m.group(2)) + step = float(m.group(3)) if m.group(3) is not None else 1 + + valslist_ext += np.arange(start, end + step, step).tolist() + elif mc is not None: + start = float(mc.group(1)) + end = float(mc.group(2)) + num = int(mc.group(3)) if mc.group(3) is not None else 1 + + valslist_ext += np.linspace(start=start, stop=end, num=num).tolist() + else: + valslist_ext.append(val) + + valslist = valslist_ext + elif opt.type == str_permutations: + valslist = list(permutations(valslist)) + + valslist = [opt.type(x) for x in valslist] + + # Confirm options are valid before starting + if opt.confirm: + opt.confirm(p, valslist) + + return valslist + + x_opt = 
self.current_axis_options[x_type] + xs = process_axis(x_opt, x_values) + + y_opt = self.current_axis_options[y_type] + ys = process_axis(y_opt, y_values) + + def fix_axis_seeds(axis_opt, axis_list): + if axis_opt.label in ['Seed', 'Var. seed']: + return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] + else: + return axis_list + + if not no_fixed_seeds: + xs = fix_axis_seeds(x_opt, xs) + ys = fix_axis_seeds(y_opt, ys) + + if x_opt.label == 'Steps': + total_steps = sum(xs) * len(ys) + elif y_opt.label == 'Steps': + total_steps = sum(ys) * len(xs) + else: + total_steps = p.steps * len(xs) * len(ys) + + if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr: + if x_opt.label == "Hires steps": + total_steps += sum(xs) * len(ys) + elif y_opt.label == "Hires steps": + total_steps += sum(ys) * len(xs) + elif p.hr_second_pass_steps: + total_steps += p.hr_second_pass_steps * len(xs) * len(ys) + else: + total_steps *= 2 + + total_steps *= p.n_iter + + image_cell_count = p.n_iter * p.batch_size + cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else "" + print(f"X/Y plot will create {len(xs) * len(ys) * image_cell_count} images on a {len(xs)}x{len(ys)} grid{cell_console_text}. (Total steps to process: {total_steps})") + shared.total_tqdm.updateTotal(total_steps) + + grid_infotext = [None] + + # If one of the axes is very slow to change between (like SD model + # checkpoint), then make sure it is in the outer iteration of the nested + # `for` loop. + swap_axes_processing_order = x_opt.cost > y_opt.cost + + def cell(x, y): + if shared.state.interrupted: + return Processed(p, [], p.seed, "") + + pc = copy(p) + x_opt.apply(pc, x, xs) + y_opt.apply(pc, y, ys) + + res = process_images(pc) + + if grid_infotext[0] is None: + pc.extra_generation_params = copy(pc.extra_generation_params) + + if x_opt.label != 'Nothing': + pc.extra_generation_params["X Type"] = x_opt.label + pc.extra_generation_params["X Values"] = x_values + if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds: + pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs]) + + if y_opt.label != 'Nothing': + pc.extra_generation_params["Y Type"] = y_opt.label + pc.extra_generation_params["Y Values"] = y_values + if y_opt.label in ["Seed", "Var. 
seed"] and not no_fixed_seeds: + pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys]) + + grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds) + + return res + + with SharedSettingsStackHelper(): + processed = draw_xy_grid( + p, + xs=xs, + ys=ys, + x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], + y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], + cell=cell, + draw_legend=draw_legend, + include_lone_images=include_lone_images, + swap_axes_processing_order=swap_axes_processing_order + ) + + if opts.grid_save: + images.save_image(processed.images[0], p.outpath_grids, "xy_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) + + return processed -- cgit v1.2.3 From 9fc354e1303453bbd865cbede86da4c3273ed14f Mon Sep 17 00:00:00 2001 From: EllangoK Date: Tue, 24 Jan 2023 02:22:40 -0500 Subject: implements most of xyz grid script --- scripts/xyz_grid.py | 114 +++++++++++++++++++++++++++++++++++----------------- 1 file changed, 78 insertions(+), 36 deletions(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 1a452355..494e8417 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -205,26 +205,30 @@ axis_options = [ ] -def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order): +def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order): ver_texts = [[images.GridAnnotation(y)] for y in y_labels] hor_texts = [[images.GridAnnotation(x)] for x in x_labels] + title_texts = [[images.GridAnnotation(z)] for z in z_labels] # Temporary list of all the images that are generated to be populated into the grid. 
# Will be filled with empty images for any individual step that fails to process properly - image_cache = [None] * (len(xs) * len(ys)) + image_cache = [None] * (len(xs) * len(ys) * len(zs)) processed_result = None cell_mode = "P" cell_size = (1, 1) - state.job_count = len(xs) * len(ys) * p.n_iter + state.job_count = len(xs) * len(ys) * len(zs) * p.n_iter - def process_cell(x, y, ix, iy): + def process_cell(x, y, z, ix, iy, iz): nonlocal image_cache, processed_result, cell_mode, cell_size - state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" + def index(ix, iy, iz): + return ix + iy*len(xs) + iz*len(xs)*len(ys) - processed: Processed = cell(x, y) + state.job = f"{index(ix, iy, iz) + 1} out of {len(xs) * len(ys) * len(zs)}" + + processed: Processed = cell(x, y, z) try: # this dereference will throw an exception if the image was not processed @@ -238,33 +242,40 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_ cell_size = processed_image.size processed_result.images = [Image.new(cell_mode, cell_size)] - image_cache[ix + iy * len(xs)] = processed_image + image_cache[index(ix, iy, iz)] = processed_image if include_lone_images: processed_result.images.append(processed_image) processed_result.all_prompts.append(processed.prompt) processed_result.all_seeds.append(processed.seed) processed_result.infotexts.append(processed.infotexts[0]) except: - image_cache[ix + iy * len(xs)] = Image.new(cell_mode, cell_size) + image_cache[index(ix, iy, iz)] = Image.new(cell_mode, cell_size) if swap_axes_processing_order: for ix, x in enumerate(xs): for iy, y in enumerate(ys): - process_cell(x, y, ix, iy) + for iy, y in enumerate(zs): + process_cell(x, y, z, ix, iy, iz) else: for iy, y in enumerate(ys): for ix, x in enumerate(xs): - process_cell(x, y, ix, iy) + for iz, z in enumerate(zs): + process_cell(x, y, z, ix, iy, iz) if not processed_result: - print("Unexpected error: draw_xy_grid failed to return even a single processed image") + print("Unexpected error: draw_xyz_grid failed to return even a single processed image") return Processed(p, []) - grid = images.image_grid(image_cache, rows=len(ys)) - if draw_legend: - grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts) - - processed_result.images[0] = grid + for i, title_text in enumerate(title_texts): + start_index = i * len(xs) * len(ys) + end_index = start_index + len(xs) * len(ys) + grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys)) + if draw_legend: + grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts) + if i == 0: # First position is a placeholder as mentioned above, so it can be directly replaced + processed_result.images[0] = grid + else: + processed_result.images.insert(i, grid) return processed_result @@ -291,7 +302,7 @@ re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+ class Script(scripts.Script): def title(self): - return "X/Y plot" + return "X/Y/Z plot" def ui(self, is_img2img): self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img] @@ -301,24 +312,35 @@ class Script(scripts.Script): with gr.Row(): x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) - fill_x_button = ToolButton(value=fill_values_symbol, 
elem_id="xy_grid_fill_x_tool_button", visible=False) + fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False) with gr.Row(): y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) - fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False) + fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyzz_grid_fill_y_tool_button", visible=False) + + with gr.Row(): + z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type")) + z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values")) + fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False) with gr.Row(variant="compact", elem_id="axis_options"): draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) - swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button") + swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button") + swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button") + swap_xz_axes_button = gr.Button(value="Swap X/Z axes", elem_id="xz_grid_swap_axes_button") - def swap_axes(x_type, x_values, y_type, y_values): - return self.current_axis_options[y_type].label, y_values, self.current_axis_options[x_type].label, x_values + def swap_axes(axis1_type, axis1_values, axis2_type, axis2_values): + return self.current_axis_options[axis2_type].label, axis2_values, self.current_axis_options[axis1_type].label, axis1_values - swap_args = [x_type, x_values, y_type, y_values] - swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args) + xy_swap_args = [x_type, x_values, y_type, y_values] + swap_xy_axes_button.click(swap_axes, inputs=xy_swap_args, outputs=xy_swap_args) + yz_swap_args = [y_type, y_values, z_type, z_values] + swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args) + xz_swap_args = [x_type, x_values, z_type, z_values] + swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args) def fill(x_type): axis = self.current_axis_options[x_type] @@ -326,16 +348,18 @@ class Script(scripts.Script): fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values]) fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values]) + fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values]) def select_axis(x_type): return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None) x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button]) y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) + z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button]) - return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] + return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, no_fixed_seeds] - def 
run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds): + def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, no_fixed_seeds): if not no_fixed_seeds: modules.processing.fix_seed(p) @@ -409,6 +433,9 @@ class Script(scripts.Script): y_opt = self.current_axis_options[y_type] ys = process_axis(y_opt, y_values) + z_opt = self.current_axis_options[z_type] + zs = process_axis(z_opt, z_values) + def fix_axis_seeds(axis_opt, axis_list): if axis_opt.label in ['Seed', 'Var. seed']: return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] @@ -418,21 +445,26 @@ class Script(scripts.Script): if not no_fixed_seeds: xs = fix_axis_seeds(x_opt, xs) ys = fix_axis_seeds(y_opt, ys) + zs = fix_axis_seeds(z_opt, zs) if x_opt.label == 'Steps': - total_steps = sum(xs) * len(ys) + total_steps = sum(xs) * len(ys) * len(zs) elif y_opt.label == 'Steps': - total_steps = sum(ys) * len(xs) + total_steps = sum(ys) * len(xs) * len(zs) + elif z_opt.label == 'Steps': + total_steps = sum(zs) * len(xs) * len(ys) else: - total_steps = p.steps * len(xs) * len(ys) + total_steps = p.steps * len(xs) * len(ys) * len(zs) if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr: if x_opt.label == "Hires steps": - total_steps += sum(xs) * len(ys) + total_steps += sum(xs) * len(ys) * len(zs) elif y_opt.label == "Hires steps": - total_steps += sum(ys) * len(xs) + total_steps += sum(ys) * len(xs) * len(zs) + elif z_opt.label == "Hires steps": + total_steps += sum(zs) * len(xs) * len(ys) elif p.hr_second_pass_steps: - total_steps += p.hr_second_pass_steps * len(xs) * len(ys) + total_steps += p.hr_second_pass_steps * len(xs) * len(ys) * len(zs) else: total_steps *= 2 @@ -440,7 +472,8 @@ class Script(scripts.Script): image_cell_count = p.n_iter * p.batch_size cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else "" - print(f"X/Y plot will create {len(xs) * len(ys) * image_cell_count} images on a {len(xs)}x{len(ys)} grid{cell_console_text}. (Total steps to process: {total_steps})") + plural_s = 's' if len(zs) > 1 else '' + print(f"X/Y plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})") shared.total_tqdm.updateTotal(total_steps) grid_infotext = [None] @@ -450,13 +483,14 @@ class Script(scripts.Script): # `for` loop. swap_axes_processing_order = x_opt.cost > y_opt.cost - def cell(x, y): + def cell(x, y, z): if shared.state.interrupted: return Processed(p, [], p.seed, "") pc = copy(p) x_opt.apply(pc, x, xs) y_opt.apply(pc, y, ys) + z_opt.apply(pc, z, zs) res = process_images(pc) @@ -475,17 +509,25 @@ class Script(scripts.Script): if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds: pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys]) + if z_opt.label != 'Nothing': + pc.extra_generation_params["Z Type"] = z_opt.label + pc.extra_generation_params["Z Values"] = z_values + if z_opt.label in ["Seed", "Var. 
seed"] and not no_fixed_seeds: + pc.extra_generation_params["Fixed Z Values"] = ", ".join([str(z) for z in zs]) + grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds) return res with SharedSettingsStackHelper(): - processed = draw_xy_grid( + processed = draw_xyz_grid( p, xs=xs, ys=ys, + zs=zs, x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], + z_labels=[y_opt.format_value(p, z_opt, z) for z in zs], cell=cell, draw_legend=draw_legend, include_lone_images=include_lone_images, @@ -493,6 +535,6 @@ class Script(scripts.Script): ) if opts.grid_save: - images.save_image(processed.images[0], p.outpath_grids, "xy_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) + images.save_image(processed.images[0], p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) return processed -- cgit v1.2.3 From e46bfa5a9e9b489ae925a9c23880e34fe8d9fffa Mon Sep 17 00:00:00 2001 From: EllangoK Date: Tue, 24 Jan 2023 02:24:32 -0500 Subject: handling sub grids and merging into one --- modules/images.py | 2 +- scripts/xyz_grid.py | 29 ++++++++++++++++++----------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/modules/images.py b/modules/images.py index 3b1c5f34..0bc3d524 100644 --- a/modules/images.py +++ b/modules/images.py @@ -195,7 +195,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts): ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts] - pad_top = max(hor_text_heights) + line_spacing * 2 + pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2 result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white") result.paste(im, (pad_left, pad_top)) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 494e8417..a16653da 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -205,9 +205,9 @@ axis_options = [ ] -def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order): - ver_texts = [[images.GridAnnotation(y)] for y in y_labels] +def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, swap_axes_processing_order): hor_texts = [[images.GridAnnotation(x)] for x in x_labels] + ver_texts = [[images.GridAnnotation(y)] for y in y_labels] title_texts = [[images.GridAnnotation(z)] for z in z_labels] # Temporary list of all the images that are generated to be populated into the grid. 
@@ -266,16 +266,21 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend print("Unexpected error: draw_xyz_grid failed to return even a single processed image") return Processed(p, []) - for i, title_text in enumerate(title_texts): + grids = [None] * len(zs) + for i in range(len(zs)): start_index = i * len(xs) * len(ys) end_index = start_index + len(xs) * len(ys) grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys)) if draw_legend: grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts) - if i == 0: # First position is a placeholder as mentioned above, so it can be directly replaced - processed_result.images[0] = grid - else: - processed_result.images.insert(i, grid) + + grids[i] = grid + if include_sub_grids and len(zs) > 1: + processed_result.images.insert(i+1, grid) + + original_grid_size = grids[0].size + grids = images.image_grid(grids, rows=1) + processed_result.images[0] = images.draw_grid_annotations(grids, original_grid_size[0], original_grid_size[1], title_texts, [[images.GridAnnotation()]]) return processed_result @@ -326,7 +331,8 @@ class Script(scripts.Script): with gr.Row(variant="compact", elem_id="axis_options"): draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) - include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) + include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images")) + include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids")) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button") swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button") @@ -357,9 +363,9 @@ class Script(scripts.Script): y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button]) - return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, no_fixed_seeds] + return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds] - def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, no_fixed_seeds): + def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds): if not no_fixed_seeds: modules.processing.fix_seed(p) @@ -527,10 +533,11 @@ class Script(scripts.Script): zs=zs, x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], - z_labels=[y_opt.format_value(p, z_opt, z) for z in zs], + z_labels=[z_opt.format_value(p, z_opt, z) for z in zs], cell=cell, draw_legend=draw_legend, include_lone_images=include_lone_images, + include_sub_grids=include_sub_grids, swap_axes_processing_order=swap_axes_processing_order ) -- cgit v1.2.3 From ec8774729e17f87a8ffa5a3c5328d12834cbb02a Mon Sep 17 00:00:00 2001 From: EllangoK Date: Tue, 24 Jan 2023 02:53:35 -0500 Subject: swaps xyz axes internally if one costs more --- scripts/xyz_grid.py | 64 +++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 52 insertions(+), 12 deletions(-) diff --git 
a/scripts/xyz_grid.py b/scripts/xyz_grid.py index a16653da..828c2d12 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -205,7 +205,7 @@ axis_options = [ ] -def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, swap_axes_processing_order): +def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed): hor_texts = [[images.GridAnnotation(x)] for x in x_labels] ver_texts = [[images.GridAnnotation(y)] for y in y_labels] title_texts = [[images.GridAnnotation(z)] for z in z_labels] @@ -224,7 +224,7 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend nonlocal image_cache, processed_result, cell_mode, cell_size def index(ix, iy, iz): - return ix + iy*len(xs) + iz*len(xs)*len(ys) + return ix + iy * len(xs) + iz * len(xs) * len(ys) state.job = f"{index(ix, iy, iz) + 1} out of {len(xs) * len(ys) * len(zs)}" @@ -251,16 +251,36 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend except: image_cache[index(ix, iy, iz)] = Image.new(cell_mode, cell_size) - if swap_axes_processing_order: + if first_axes_processed == 'x': for ix, x in enumerate(xs): - for iy, y in enumerate(ys): - for iy, y in enumerate(zs): - process_cell(x, y, z, ix, iy, iz) - else: + if second_axes_processed == 'y': + for iy, y in enumerate(ys): + for iz, z in enumerate(zs): + process_cell(x, y, z, ix, iy, iz) + else: + for iz, z in enumerate(zs): + for iy, y in enumerate(ys): + process_cell(x, y, z, ix, iy, iz) + elif first_axes_processed == 'y': for iy, y in enumerate(ys): - for ix, x in enumerate(xs): + if second_axes_processed == 'x': + for ix, x in enumerate(xs): + for iz, z in enumerate(zs): + process_cell(x, y, z, ix, iy, iz) + else: for iz, z in enumerate(zs): - process_cell(x, y, z, ix, iy, iz) + for ix, x in enumerate(xs): + process_cell(x, y, z, ix, iy, iz) + elif first_axes_processed == 'z': + for iz, z in enumerate(zs): + if second_axes_processed == 'x': + for ix, x in enumerate(xs): + for iy, y in enumerate(ys): + process_cell(x, y, z, ix, iy, iz) + else: + for iy, y in enumerate(ys): + for ix, x in enumerate(xs): + process_cell(x, y, z, ix, iy, iz) if not processed_result: print("Unexpected error: draw_xyz_grid failed to return even a single processed image") @@ -322,7 +342,7 @@ class Script(scripts.Script): with gr.Row(): y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) - fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyzz_grid_fill_y_tool_button", visible=False) + fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False) with gr.Row(): z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type")) @@ -487,7 +507,26 @@ class Script(scripts.Script): # If one of the axes is very slow to change between (like SD model # checkpoint), then make sure it is in the outer iteration of the nested # `for` loop. 
- swap_axes_processing_order = x_opt.cost > y_opt.cost + first_axes_processed = 'x' + second_axes_processed = 'y' + if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost: + first_axes_processed = 'x' + if y_opt.cost > z_opt.cost: + second_axes_processed = 'y' + else: + second_axes_processed = 'z' + elif y_opt.cost > x_opt.cost and y_opt.cost > z_opt.cost: + first_axes_processed = 'y' + if x_opt.cost > z_opt.cost: + second_axes_processed = 'x' + else: + second_axes_processed = 'z' + elif z_opt.cost > x_opt.cost and z_opt.cost > y_opt.cost: + first_axes_processed = 'z' + if x_opt.cost > y_opt.cost: + second_axes_processed = 'x' + else: + second_axes_processed = 'y' def cell(x, y, z): if shared.state.interrupted: @@ -538,7 +577,8 @@ class Script(scripts.Script): draw_legend=draw_legend, include_lone_images=include_lone_images, include_sub_grids=include_sub_grids, - swap_axes_processing_order=swap_axes_processing_order + first_axes_processed=first_axes_processed, + second_axes_processed=second_axes_processed ) if opts.grid_save: -- cgit v1.2.3 From dac45299dd57c6cb240424b93fd28a085605bd90 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 20:22:19 +0300 Subject: make git commands not fail for extensions when you have spaces in webui directory --- launch.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/launch.py b/launch.py index 82094fa0..6d523a34 100644 --- a/launch.py +++ b/launch.py @@ -108,18 +108,18 @@ def git_clone(url, dir, name, commithash=None): if commithash is None: return - current_hash = run(f'"{git}" -C {dir} rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip() + current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip() if current_hash == commithash: return - run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") - run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}") + run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") + run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}") return run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}") if commithash is not None: - run(f'"{git}" -C {dir} checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") + run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") def version_check(commit): -- cgit v1.2.3 From 28189985e6f56dc725938a3f0e4d2462dad74bc5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 20:24:27 +0300 Subject: remove fairscale requirement, add fake fairscale to make BLIP not complain about it --- modules/interrogate.py | 11 +++++++++-- requirements.txt | 1 - requirements_versions.txt | 1 - 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/modules/interrogate.py b/modules/interrogate.py index 236e6983..9f063197 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -82,9 +82,16 @@ class InterrogateModels: return self.loaded_categories + def create_fake_fairscale(self): + class FakeFairscale: + def checkpoint_wrapper(self): + pass + + sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale + def 
load_blip_model(self): - with paths.Prioritize("BLIP"): - import models.blip + create_fake_fairscale() + import models.blip files = modelloader.load_models( model_path=os.path.join(paths.models_path, "BLIP"), diff --git a/requirements.txt b/requirements.txt index ef5e3472..a4be1ec3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ blendmodes accelerate basicsr -fairscale==0.4.4 fonts font-roboto gfpgan diff --git a/requirements_versions.txt b/requirements_versions.txt index f97ad765..135908be 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -14,7 +14,6 @@ scikit-image==0.19.2 fonts font-roboto timm==0.6.7 -fairscale==0.4.9 piexif==1.1.3 einops==0.4.1 jsonmerge==1.8.0 -- cgit v1.2.3 From 5228ec8bdada50a8d614573e980193ca89192361 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 20:30:43 +0300 Subject: remove fairscale requirement, add fake fairscale to make BLIP not complain about it mk2 --- modules/interrogate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/interrogate.py b/modules/interrogate.py index 9f063197..c72ff694 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -90,7 +90,7 @@ class InterrogateModels: sys.modules["fairscale.nn.checkpoint.checkpoint_activations"] = FakeFairscale def load_blip_model(self): - create_fake_fairscale() + self.create_fake_fairscale() import models.blip files = modelloader.load_models( -- cgit v1.2.3 From 93fad28a979727f9b1331dbdc447598824057cdc Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 21:13:05 +0300 Subject: print progress when installing torch add PIP_INSTALLER_LOCATION env var to install pip if it's not installed remove accidental call to accelerate when venv is disabled add another env var to skip venv - SKIP_VENV --- launch.py | 20 +++++++++++++++++--- webui.bat | 8 +++++--- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/launch.py b/launch.py index 6d523a34..f578c1c7 100644 --- a/launch.py +++ b/launch.py @@ -48,10 +48,19 @@ def extract_opt(args, name): return args, is_present, opt -def run(command, desc=None, errdesc=None, custom_env=None): +def run(command, desc=None, errdesc=None, custom_env=None, live=False): if desc is not None: print(desc) + if live: + result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env) + if result.returncode != 0: + raise RuntimeError(f"""{errdesc or 'Error running command'}. 
+Command: {command} +Error code: {result.returncode}""") + + return "" + result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env) if result.returncode != 0: @@ -179,6 +188,8 @@ def run_extensions_installers(settings_file): def prepare_environment(): global skip_install + pip_installer_location = os.environ.get('PIP_INSTALLER_LOCATION', None) + torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") commandline_args = os.environ.get('COMMANDLINE_ARGS', "") @@ -219,9 +230,12 @@ def prepare_environment(): print(f"Python {sys.version}") print(f"Commit hash: {commit}") - + + if pip_installer_location is not None and not is_installed("pip"): + run(f'"{python}" "{pip_installer_location}"', "Installing pip", "Couldn't install pip") + if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"): - run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch") + run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True) if not skip_torch_cuda_test: run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'") diff --git a/webui.bat b/webui.bat index 3165b94d..0d6865c9 100644 --- a/webui.bat +++ b/webui.bat @@ -3,6 +3,7 @@ if not defined PYTHON (set PYTHON=python) if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv") + set ERROR_REPORTING=FALSE mkdir tmp 2>NUL @@ -14,6 +15,7 @@ goto :show_stdout_stderr :start_venv if ["%VENV_DIR%"] == ["-"] goto :skip_venv +if ["%SKIP_VENV%"] == ["1"] goto :skip_venv dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt if %ERRORLEVEL% == 0 goto :activate_venv @@ -28,13 +30,13 @@ goto :show_stdout_stderr :activate_venv set PYTHON="%VENV_DIR%\Scripts\Python.exe" echo venv %PYTHON% -if [%ACCELERATE%] == ["True"] goto :accelerate -goto :launch :skip_venv +if [%ACCELERATE%] == ["True"] goto :accelerate +goto :launch :accelerate -echo "Checking for accelerate" +echo Checking for accelerate set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe" if EXIST %ACCELERATE% goto :accelerate_launch -- cgit v1.2.3 From bef193189500884c2b20605290ac8bef8251a788 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 24 Jan 2023 23:50:04 +0300 Subject: add fastapi to requirements --- requirements_versions.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements_versions.txt b/requirements_versions.txt index 135908be..1c328d44 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -4,6 +4,7 @@ accelerate==0.12.0 basicsr==1.4.2 gfpgan==1.3.8 gradio==3.16.2 +fastapi==0.82.0 numpy==1.23.3 Pillow==9.4.0 realesrgan==0.3.0 -- cgit v1.2.3 From 48a15821de768fea76e66f26df83df3fddf18f4b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 25 Jan 2023 00:49:16 +0300 Subject: remove the pip install stuff because it does not work as i hoped it would --- launch.py | 5 ----- requirements_versions.txt | 1 - webui.bat | 13 +++++++++++-- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/launch.py b/launch.py index f578c1c7..9d6f4a8c 100644 --- a/launch.py +++ b/launch.py @@ -188,8 +188,6 @@ def 
run_extensions_installers(settings_file): def prepare_environment(): global skip_install - pip_installer_location = os.environ.get('PIP_INSTALLER_LOCATION', None) - torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") commandline_args = os.environ.get('COMMANDLINE_ARGS', "") @@ -231,9 +229,6 @@ def prepare_environment(): print(f"Python {sys.version}") print(f"Commit hash: {commit}") - if pip_installer_location is not None and not is_installed("pip"): - run(f'"{python}" "{pip_installer_location}"', "Installing pip", "Couldn't install pip") - if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"): run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True) diff --git a/requirements_versions.txt b/requirements_versions.txt index 1c328d44..135908be 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -4,7 +4,6 @@ accelerate==0.12.0 basicsr==1.4.2 gfpgan==1.3.8 gradio==3.16.2 -fastapi==0.82.0 numpy==1.23.3 Pillow==9.4.0 realesrgan==0.3.0 diff --git a/webui.bat b/webui.bat index 0d6865c9..209d972b 100644 --- a/webui.bat +++ b/webui.bat @@ -9,10 +9,19 @@ set ERROR_REPORTING=FALSE mkdir tmp 2>NUL %PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt -if %ERRORLEVEL% == 0 goto :start_venv +if %ERRORLEVEL% == 0 goto :check_pip echo Couldn't launch python goto :show_stdout_stderr +:check_pip +%PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt +if %ERRORLEVEL% == 0 goto :start_venv +if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr +%PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt +if %ERRORLEVEL% == 0 goto :start_venv +echo Couldn't install pip +goto :show_stdout_stderr + :start_venv if ["%VENV_DIR%"] == ["-"] goto :skip_venv if ["%SKIP_VENV%"] == ["1"] goto :skip_venv @@ -46,7 +55,7 @@ pause exit /b :accelerate_launch -echo "Accelerating" +echo Accelerating %ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py pause exit /b -- cgit v1.2.3 From 84d9ce30cb427759547bc7876ed80ab91787d175 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 24 Jan 2023 23:51:45 -0500 Subject: Add option for float32 sampling with float16 UNet This also handles type casting so that ROCm and MPS torch devices work correctly without --no-half. One cast is required for deepbooru in deepbooru_model.py, some explicit casting is required for img2img and inpainting. depth_model can't be converted to float16 or it won't work correctly on some systems (it's known to have issues on MPS) so in sd_models.py model.depth_model is removed for model.half(). 
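[Editorial sketch, not part of the patch: the essence of --upcast-sampling is that the UNet weights stay float16 while the sampler's tensors stay float32, with casts only at the model boundary. In the patch this is routed through devices.dtype_unet / devices.unet_needs_upcast and the CondFunc hooks; the toy names below (w, apply_model) are illustrative stand-ins, not code from the repository.]

    import torch

    w = torch.randn(4, dtype=torch.float16)   # stand-in for the fp16 UNet weights
    x = torch.randn(4)                        # sampler state kept in fp32

    def apply_model(x):
        # cast the input down to the model's dtype, run it, upcast the result
        return (w * x.to(torch.float16)).float()

    out = apply_model(x)
    assert out.dtype == torch.float32         # sampling math continues in fp32
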
--- README.md | 1 + modules/deepbooru_model.py | 4 +++- modules/devices.py | 2 ++ modules/processing.py | 15 ++++++++------- modules/sd_hijack_unet.py | 29 +++++++++++++++++++++++++++++ modules/sd_hijack_utils.py | 28 ++++++++++++++++++++++++++++ modules/sd_models.py | 10 ++++++++++ modules/shared.py | 1 + 8 files changed, 82 insertions(+), 8 deletions(-) create mode 100644 modules/sd_hijack_utils.py diff --git a/README.md b/README.md index 9c0cd1ef..a5611671 100644 --- a/README.md +++ b/README.md @@ -157,4 +157,5 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru - Security advice - RyotaK - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. +- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6) - (You) diff --git a/modules/deepbooru_model.py b/modules/deepbooru_model.py index edd40c81..83d2ff09 100644 --- a/modules/deepbooru_model.py +++ b/modules/deepbooru_model.py @@ -2,6 +2,8 @@ import torch import torch.nn as nn import torch.nn.functional as F +from modules import devices + # see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more @@ -196,7 +198,7 @@ class DeepDanbooruModel(nn.Module): t_358, = inputs t_359 = t_358.permute(*[0, 3, 1, 2]) t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0) - t_360 = self.n_Conv_0(t_359_padded) + t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded) t_361 = F.relu(t_360) t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf')) t_362 = self.n_MaxPool_0(t_361) diff --git a/modules/devices.py b/modules/devices.py index 524ec7af..0981ef80 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -79,6 +79,8 @@ cpu = torch.device("cpu") device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None dtype = torch.float16 dtype_vae = torch.float16 +dtype_unet = torch.float16 +unet_needs_upcast = False def randn(seed, shape): diff --git a/modules/processing.py b/modules/processing.py index bc541e2f..2d186ba0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -172,7 +172,8 @@ class StableDiffusionProcessing: midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image)) + conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image conditioning = torch.nn.functional.interpolate( self.sd_model.depth_model(midas_in), size=conditioning_image.shape[2:], @@ -203,7 +204,7 @@ class StableDiffusionProcessing: # Create another latent image, this time with a masked version of the original input. # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter. 
- conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype) + conditioning_mask = conditioning_mask.to(device=source_image.device, dtype=source_image.dtype) conditioning_image = torch.lerp( source_image, source_image * (1.0 - conditioning_mask), @@ -211,7 +212,7 @@ class StableDiffusionProcessing: ) # Encode the new masked image using first stage of network. - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image)) # Create the concatenated conditioning tensor to be fed to `c_concat` conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:]) @@ -225,10 +226,10 @@ class StableDiffusionProcessing: # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely # identify itself with a field common to all models. The conditioning_key is also hybrid. if isinstance(self.sd_model, LatentDepth2ImageDiffusion): - return self.depth2img_image_conditioning(source_image) + return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image) if self.sampler.conditioning_key in {'hybrid', 'concat'}: - return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask) # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) @@ -610,7 +611,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - with devices.autocast(): + with devices.autocast(disable=devices.unet_needs_upcast): samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] @@ -988,7 +989,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = torch.from_numpy(batch_images) image = 2. * image - 1. 
- image = image.to(shared.device) + image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None) self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 18daf8c1..88c94e54 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -1,4 +1,8 @@ import torch +from packaging import version + +from modules import devices +from modules.sd_hijack_utils import CondFunc class TorchHijackForUnet: @@ -28,3 +32,28 @@ class TorchHijackForUnet: th = TorchHijackForUnet() + + +# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling +def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): + for y in cond.keys(): + cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + with devices.autocast(): + return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() + +class GELUHijack(torch.nn.GELU, torch.nn.Module): + def __init__(self, *args, **kwargs): + torch.nn.GELU.__init__(self, *args, **kwargs) + def forward(self, x): + if devices.unet_needs_upcast: + return torch.nn.GELU.forward(self.float(), x.float()).to(devices.dtype_unet) + else: + return torch.nn.GELU.forward(self, x) + +unet_needs_upcast = lambda *args, **kwargs: devices.unet_needs_upcast +CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.apply_model', apply_model, unet_needs_upcast) +CondFunc('ldm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).to(devices.dtype_unet), unet_needs_upcast) +if version.parse(torch.__version__) <= version.parse("1.13.1"): + CondFunc('ldm.modules.diffusionmodules.util.GroupNorm32.forward', lambda orig_func, self, *args, **kwargs: orig_func(self.float(), *args, **kwargs), unet_needs_upcast) + CondFunc('ldm.modules.attention.GEGLU.forward', lambda orig_func, self, x: orig_func(self.float(), x.float()).to(devices.dtype_unet), unet_needs_upcast) + CondFunc('open_clip.transformer.ResidualAttentionBlock.__init__', lambda orig_func, *args, **kwargs: kwargs.update({'act_layer': GELUHijack}) and False or orig_func(*args, **kwargs), lambda _, *args, **kwargs: kwargs.get('act_layer') is None or kwargs['act_layer'] == torch.nn.GELU) diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py new file mode 100644 index 00000000..f81b169a --- /dev/null +++ b/modules/sd_hijack_utils.py @@ -0,0 +1,28 @@ +import importlib + +class CondFunc: + def __new__(cls, orig_func, sub_func, cond_func): + self = super(CondFunc, cls).__new__(cls) + if isinstance(orig_func, str): + func_path = orig_func.split('.') + for i in range(len(func_path)-2, -1, -1): + try: + resolved_obj = importlib.import_module('.'.join(func_path[:i])) + break + except ImportError: + pass + for attr_name in func_path[i:-1]: + resolved_obj = getattr(resolved_obj, attr_name) + orig_func = getattr(resolved_obj, func_path[-1]) + setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) + self.__init__(orig_func, sub_func, cond_func) + return lambda *args, **kwargs: self(*args, **kwargs) + def __init__(self, orig_func, sub_func, cond_func): + self.__orig_func = orig_func + self.__sub_func = sub_func + self.__cond_func = cond_func + def __call__(self, *args, **kwargs): + if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): + return self.__sub_func(self.__orig_func, 
*args, **kwargs) + else: + return self.__orig_func(*args, **kwargs) diff --git a/modules/sd_models.py b/modules/sd_models.py index 12083848..7c98991a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -257,16 +257,24 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo): if not shared.cmd_opts.no_half: vae = model.first_stage_model + depth_model = getattr(model, 'depth_model', None) # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 if shared.cmd_opts.no_half_vae: model.first_stage_model = None + # with --upcast-sampling, don't convert the depth model weights to float16 + if shared.cmd_opts.upcast_sampling and depth_model: + model.depth_model = None model.half() model.first_stage_model = vae + if depth_model: + model.depth_model = depth_model devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 + devices.dtype_unet = model.model.diffusion_model.dtype + devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 model.first_stage_model.to(devices.dtype_vae) @@ -372,6 +380,8 @@ def load_model(checkpoint_info=None): if shared.cmd_opts.no_half: sd_config.model.params.unet_config.params.use_fp16 = False + elif shared.cmd_opts.upcast_sampling: + sd_config.model.params.unet_config.params.use_fp16 = True timer = Timer() diff --git a/modules/shared.py b/modules/shared.py index 5f713bee..4ce1209b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -45,6 +45,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") +parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site") parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us") -- cgit v1.2.3 From e3b53fd295aca784253dfc8668ec87b537a72f43 Mon Sep 17 00:00:00 2001 From: brkirch Date: Wed, 25 Jan 2023 00:23:10 -0500 Subject: Add UI setting for upcasting attention to float32 Adds "Upcast cross attention layer to float32" option in Stable Diffusion settings. This allows for generating images using SD 2.1 models without --no-half or xFormers. In order to make upcasting cross attention layer optimizations possible it is necessary to indent several sections of code in sd_hijack_optimizations.py so that a context manager can be used to disable autocast. Also, even though Stable Diffusion (and Diffusers) only upcast q and k, unfortunately my findings were that most of the cross attention layer optimizations could not function unless v is upcast also. 
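[Editorial sketch, not part of the patch: stripped of the memory-management code, "upcast cross attention" means q and k (and, per the findings above, usually v as well) are promoted to float32 before the score/softmax math, then the result is cast back to the layer's working dtype — float16 softmax over large score ranges loses precision or overflows, which is the NaN case the patch's error message describes. attention_upcast below is a hypothetical reduction; the real change threads shared.opts.upcast_attn through each optimization in sd_hijack_optimizations.py.]

    import torch

    def attention_upcast(q, k, v):
        # q, k, v arrive in float16; do the numerically sensitive math in fp32
        dtype = q.dtype
        q, k, v = q.float(), k.float(), v.float()
        scores = torch.einsum('bid,bjd->bij', q, k) * (q.shape[-1] ** -0.5)
        out = torch.einsum('bij,bjd->bid', scores.softmax(dim=-1), v)
        return out.to(dtype)                  # back to the layer's dtype

    q = k = v = torch.randn(1, 8, 16, dtype=torch.float16)
    print(attention_upcast(q, k, v).dtype)    # torch.float16
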
--- modules/devices.py | 6 +- modules/processing.py | 2 +- modules/sd_hijack_optimizations.py | 159 +++++++++++++++++++++++-------------- modules/shared.py | 1 + modules/sub_quadratic_attention.py | 4 +- 5 files changed, 108 insertions(+), 64 deletions(-) diff --git a/modules/devices.py b/modules/devices.py index 0981ef80..6b36622c 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -108,6 +108,10 @@ def autocast(disable=False): return torch.autocast("cuda") +def without_autocast(disable=False): + return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() + + class NansException(Exception): pass @@ -125,7 +129,7 @@ def test_for_nans(x, where): message = "A tensor with all NaNs was produced in Unet." if not shared.cmd_opts.no_half: - message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this." + message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this." elif where == "vae": message = "A tensor with all NaNs was produced in VAE." diff --git a/modules/processing.py b/modules/processing.py index 2d186ba0..a850082d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -611,7 +611,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - with devices.autocast(disable=devices.unet_needs_upcast): + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 74452709..c02d954c 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -9,7 +9,7 @@ from torch import einsum from ldm.util import default from einops import rearrange -from modules import shared, errors +from modules import shared, errors, devices from modules.hypernetworks import hypernetwork from .sub_quadratic_attention import efficient_dot_product_attention @@ -52,18 +52,25 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - for i in range(0, q.shape[0], 2): - end = i + 2 - s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) - s1 *= self.scale + dtype = q.dtype + if shared.opts.upcast_attn: + q, k, v = q.float(), k.float(), v.float() - s2 = s1.softmax(dim=-1) - del s1 + with devices.without_autocast(disable=not shared.opts.upcast_attn): + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[0], 2): + end = i + 2 + s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) + s1 *= self.scale + + s2 = s1.softmax(dim=-1) + del s1 + + r1[i:end] = einsum('b i j, b j d -> b i 
d', s2, v[i:end]) + del s2 + del q, k, v - r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) - del s2 - del q, k, v + r1 = r1.to(dtype) r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) del r1 @@ -82,45 +89,52 @@ def split_cross_attention_forward(self, x, context=None, mask=None): k_in = self.to_k(context_k) v_in = self.to_v(context_v) - k_in *= self.scale - - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - - mem_free_total = get_available_vram() - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - if mem_required > mem_free_total: - steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) - # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + dtype = q_in.dtype + if shared.opts.upcast_attn: + q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float() - if steps > 64: - max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 - raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' - f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) - - s2 = s1.softmax(dim=-1, dtype=q.dtype) - del s1 - - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 + with devices.without_autocast(disable=not shared.opts.upcast_attn): + k_in = k_in * self.scale + + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + + mem_free_total = get_available_vram() + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + if mem_required > mem_free_total: + steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). 
' + f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) + + s2 = s1.softmax(dim=-1, dtype=q.dtype) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v - del q, k, v + r1 = r1.to(dtype) r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) del r1 @@ -204,12 +218,20 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): context = default(context, x) context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) - k = self.to_k(context_k) * self.scale + k = self.to_k(context_k) v = self.to_v(context_v) del context, context_k, context_v, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - r = einsum_op(q, k, v) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float() + + with devices.without_autocast(disable=not shared.opts.upcast_attn): + k = k * self.scale + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + r = einsum_op(q, k, v) + r = r.to(dtype) return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) # -- End of code from https://github.com/invoke-ai/InvokeAI -- @@ -234,8 +256,14 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + x = x.to(dtype) + x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2) out_proj, dropout = self.to_out @@ -268,15 +296,16 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_ query_chunk_size = q_tokens kv_chunk_size = k_tokens - return efficient_dot_product_attention( - q, - k, - v, - query_chunk_size=q_chunk_size, - kv_chunk_size=kv_chunk_size, - kv_chunk_size_min = kv_chunk_size_min, - use_checkpoint=use_checkpoint, - ) + with devices.without_autocast(disable=q.dtype == v.dtype): + return efficient_dot_product_attention( + q, + k, + v, + query_chunk_size=q_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min = kv_chunk_size_min, + use_checkpoint=use_checkpoint, + ) def get_xformers_flash_attention_op(q, k, v): @@ -306,8 +335,14 @@ def xformers_attention_forward(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v)) + out = out.to(dtype) + out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) @@ -378,10 +413,14 @@ def xformers_attnblock_forward(self, x): v = self.v(h_) b, c, h, w = q.shape q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() q = q.contiguous() k = k.contiguous() v = v.contiguous() out = 
xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) + out = out.to(dtype) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) return x + out diff --git a/modules/shared.py b/modules/shared.py index 4ce1209b..6a0b96cb 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -410,6 +410,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), })) options_templates.update(options_section(('compatibility', "Compatibility"), { diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 55052815..05595323 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -67,7 +67,7 @@ def _summarize_chunk( max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score = max_score.detach() exp_weights = torch.exp(attn_weights - max_score) - exp_values = torch.bmm(exp_weights, value) + exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype) max_score = max_score.squeeze(-1) return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) @@ -129,7 +129,7 @@ def _get_attention_scores_no_kv_chunking( ) attn_probs = attn_scores.softmax(dim=-1) del attn_scores - hidden_states_slice = torch.bmm(attn_probs, value) + hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype) return hidden_states_slice -- cgit v1.2.3 From 1bfec873fa13d803f3d4ac2a12bf6983838233fe Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 25 Jan 2023 11:29:46 +0300 Subject: add an experimental option to apply loras to outputs rather than inputs --- extensions-builtin/Lora/lora.py | 5 ++++- extensions-builtin/Lora/scripts/lora_script.py | 7 ++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 137e58f7..cb8f1d36 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -166,7 +166,10 @@ def lora_forward(module, input, res): for lora in loaded_loras: module = lora.modules.get(lora_layer_name, None) if module is not None: - res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + if shared.opts.lora_apply_to_outputs and res.shape == input.shape: + res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + else: + res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) return res diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 60b9eb64..544b228d 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -3,7 +3,7 @@ 
import torch import lora import extra_networks_lora import ui_extra_networks_lora -from modules import script_callbacks, ui_extra_networks, extra_networks +from modules import script_callbacks, ui_extra_networks, extra_networks, shared def unload(): @@ -28,3 +28,8 @@ torch.nn.Conv2d.forward = lora.lora_Conv2d_forward script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) script_callbacks.on_script_unloaded(unload) script_callbacks.on_before_ui(before_ui) + + +shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { + "lora_apply_to_outputs": shared.OptionInfo(False, "Apply Lora to outputs rather than inputs when possible (experimental)"), +})) -- cgit v1.2.3 From ee0a0da3244123cb6d2ba4097a54a1e9caccb687 Mon Sep 17 00:00:00 2001 From: Kyle Date: Wed, 25 Jan 2023 08:53:23 -0500 Subject: Add instruct-pix2pix hijack Allows loading instruct-pix2pix models via same method as inpainting models in sd_models.py and sd_hijack_ip2p.py Adds ddpm_edit.py necessary for instruct-pix2pix --- modules/models/diffusion/ddpm_edit.py | 1459 +++++++++++++++++++++++++++++++++ modules/sd_hijack_ip2p.py | 13 + modules/sd_models.py | 12 +- 3 files changed, 1483 insertions(+), 1 deletion(-) create mode 100644 modules/models/diffusion/ddpm_edit.py create mode 100644 modules/sd_hijack_ip2p.py diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py new file mode 100644 index 00000000..f3d49c44 --- /dev/null +++ b/modules/models/diffusion/ddpm_edit.py @@ -0,0 +1,1459 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion). +# See more details in LICENSE. 
+ +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.distributed import rank_zero_only + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + load_ema=True, + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + + if self.use_ema and load_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + # If initialing from EMA-only checkpoint, create EMA model after loading. 
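+ # (load_ema=False handles checkpoints that contain only EMA weights: there
+ # is no separate EMA copy to restore from the checkpoint, so a fresh one is
+ # built from the weights that were just loaded.)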
+ if self.use_ema and not load_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. 
* 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + + # Our model adds additional channels to the first layer to condition on an input image. + # For the first layer, copy existing channel weights and initialize new channel weights to zero. + input_keys = [ + "model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight", + ] + + self_sd = self.state_dict() + for input_key in input_keys: + if input_key not in sd or input_key not in self_sd: + continue + + input_weight = self_sd[input_key] + + if input_weight.size() != sd[input_key].size(): + print(f"Manual init: {input_key}") + input_weight.zero_() + input_weight[:, :4, :, :].copy_(sd[input_key]) + ignore_keys.append(input_key) + + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
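+ In closed form: q(x_t | x_0) = N(sqrt(alphas_cumprod[t]) * x_0, (1 - alphas_cumprod[t]) * I).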
+ """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + 
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + return batch[k] + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + 
log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + load_ema=True, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, load_ema=load_ema, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + if self.use_ema and not load_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. 
/ z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = 
torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None, uncond=0.05): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + cond_key = cond_key or self.cond_stage_key + xc = super().get_input(batch, cond_key) + 
if bs is not None: + xc["c_crossattn"] = xc["c_crossattn"][:bs] + xc["c_concat"] = xc["c_concat"][:bs] + cond = {} + + # To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%. + random = torch.rand(x.size(0), device=x.device) + prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1") + input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1") + + null_prompt = self.get_learned_conditioning([""]) + cond["c_crossattn"] = [torch.where(prompt_mask, null_prompt, self.get_learned_conditioning(xc["c_crossattn"]).detach())] + cond["c_concat"] = [input_mask * self.encode_first_stage((xc["c_concat"].to(self.device))).mode().detach()] + + out = [z, cond] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + # same as above but without decorator + def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. 
/ self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. 
(64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def rescale_bbox(bbox): + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def apply_model(self, x_noisy, t, cond, return_ids=False): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert len(cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. 
(64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + + if self.cond_stage_key in ["image", "LR_image", "segmentation", + 'bbox_img'] and self.model.conditioning_key: # todo check for completeness + c_key = next(iter(cond.keys())) # get key + c = next(iter(cond.values())) # get value + assert (len(c) == 1) # todo extend to list with more than one elem + c = c[0] # get element + + c = unfold(c) + c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + + elif self.cond_stage_key == 'coordinates_bbox': + assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + + # assuming padding of unfold is always 0 and its dilation is always 1 + n_patches_per_row = int((w - ks[0]) / stride[0] + 1) + full_img_h, full_img_w = self.split_input_params['original_image_size'] + # as we are operating on latents, we need the factor from the original image size to the + # spatial latent size to properly rescale the crops for regenerating the bbox annotations + num_downs = self.first_stage_model.encoder.num_resolutions - 1 + rescale_latent = 2 ** (num_downs) + + # get top left postions of patches as conforming for the bbbox tokenizer, therefore we + # need to rescale the tl patch coordinates to be in between (0,1) + tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, + rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) + for patch_nr in range(z.shape[-1])] + + # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) + patch_limits = [(x_tl, y_tl, + rescale_latent * ks[0] / full_img_w, + rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] + # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] + + # tokenize crop coordinates for the bounding boxes of the respective patches + patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) + for bbox in patch_limits] # list of length l with tensors of shape (1, 2) + print(patch_limits_tknzd[0].shape) + # cut tknzd crop position from conditioning + assert isinstance(cond, dict), 'cond must be dict to be fed into model' + cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) + print(cut_cond.shape) + + adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) + adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') + print(adapted_cond.shape) + adapted_cond = self.get_learned_conditioning(adapted_cond) + print(adapted_cond.shape) + adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + print(adapted_cond.shape) + + cond_list = [{'c_crossattn': [e]} for e in adapted_cond] + + else: + cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient + + # apply model by loop over crops + output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] + assert not isinstance(output_list[0], + tuple) # todo cant deal with multiple model outputs check this never happens + + o = torch.stack(output_list, 
axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) 
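+ # For the "eps" parameterization, predict_start_from_noise inverts
+ # x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * eps (with a_t = alphas_cumprod[t])
+ # to recover x_0; clamping then keeps the estimate in the expected range.
+ # The branch below optionally snaps x_recon onto the first stage's VQ codebook.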
+ if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + 
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None,**kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs): + + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size, + shape,cond,verbose=False,**kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True,**kwargs) + + return samples, intermediates + + + @torch.no_grad() + def log_images(self, batch, N=4, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=False, plot_denoise_rows=False, plot_progressive_rows=False, + plot_diffusion_rows=False, **kwargs): + + use_ddim = False + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + 
force_c_encode=True, + return_original_cond=True, + bs=N, uncond=0) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reals"] = xc["c_concat"] + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"]) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with self.ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] 
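+ # mask broadcasts to (N, 1, h, w): 1 keeps the original (noised) latent in
+ # p_sample_loop, 0 marks the center square that gets regenerated.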
+ with self.ema_scope("Plotting Inpaint"): + + samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + with self.ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with self.ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. 
+ return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class Layout2ImgDiffusion(LatentDiffusion): + # TODO: move all layout-specific hacks to this class + def __init__(self, cond_stage_key, *args, **kwargs): + assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' + super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + + def log_images(self, batch, N=8, *args, **kwargs): + logs = super().log_images(batch=batch, N=N, *args, **kwargs) + + key = 'train' if self.training else 'validation' + dset = self.trainer.datamodule.datasets[key] + mapper = dset.conditional_builders[self.cond_stage_key] + + bbox_imgs = [] + map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) + for tknzd_bbox in batch[self.cond_stage_key][:N]: + bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) + bbox_imgs.append(bboximg) + + cond_img = torch.stack(bbox_imgs, dim=0) + logs['bbox_image'] = cond_img + return logs diff --git a/modules/sd_hijack_ip2p.py b/modules/sd_hijack_ip2p.py new file mode 100644 index 00000000..635f015f --- /dev/null +++ b/modules/sd_hijack_ip2p.py @@ -0,0 +1,13 @@ +import collections +import os.path +import sys +import gc +import time + +def should_hijack_ip2p(checkpoint_info): + from modules import sd_models + + ckpt_basename = os.path.basename(checkpoint_info.filename).lower() + cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower() + + return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename diff --git a/modules/sd_models.py b/modules/sd_models.py index 12083848..cddc2343 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -17,6 +17,7 @@ from ldm.util import instantiate_from_config from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting +from modules.sd_hijack_ip2p import should_hijack_ip2p model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) @@ -365,6 +366,15 @@ def load_model(checkpoint_info=None): sd_config.model.params.unet_config.params.in_channels = 9 sd_config.model.params.finetune_keys = None + if should_hijack_ip2p(checkpoint_info): + sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion" + sd_config.model.params.conditioning_key = "hybrid" + sd_config.model.params.first_stage_key = "edited" + 
sd_config.model.params.cond_stage_key = "edit" + sd_config.model.params.image_size = 16 + sd_config.model.params.unet_config.params.in_channels = 8 + sd_config.model.params.unet_config.params.out_channels = 4 + if not hasattr(sd_config.model.params, "use_ema"): sd_config.model.params.use_ema = False @@ -429,7 +439,7 @@ def reload_model_weights(sd_model=None, info=None): checkpoint_config = find_checkpoint_config(current_checkpoint_info) - if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): + if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info): del sd_model checkpoints_loaded.clear() load_model(checkpoint_info) -- cgit v1.2.3 From bd9b55ee908c43fb1b654b3a3a1320545023ce1c Mon Sep 17 00:00:00 2001 From: Kyle Date: Wed, 25 Jan 2023 09:41:41 -0500 Subject: Update requirements transformers==4.25.1 Update requirement for transformers to version 4.25.1 to allow instruct-pix2pix demo code to work --- requirements.txt | 2 +- requirements_versions.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index a4be1ec3..6d53f089 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ pytorch_lightning==1.7.7 realesrgan scikit-image>=0.19 timm==0.4.12 -transformers==4.19.2 +transformers==4.25.1 torch einops jsonmerge diff --git a/requirements_versions.txt b/requirements_versions.txt index 135908be..eaa08806 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -1,5 +1,5 @@ blendmodes==2022 -transformers==4.19.2 +transformers==4.25.1 accelerate==0.12.0 basicsr==1.4.2 gfpgan==1.3.8 -- cgit v1.2.3 From 57c1baa774d07060af0abbd2974c5f36c8cb63ac Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 25 Jan 2023 18:56:23 +0300 Subject: change to code for live preview fix on OSX to be bit more obvious --- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index 3bd590ba..57c3db1b 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -568,8 +568,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) - if shared.opts.live_previews_enable and sd_samplers.approximation_indexes.get(shared.opts.show_progress_type, 0) == 1: - # preload approx nn model before sampling for a more deterministic result + # for OSX, loading the model during sampling changes the generated picture, so it is loaded here + if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN": sd_vae_approx.model() if not p.disable_extra_networks: -- cgit v1.2.3 From 635499e8329dfd8c4c5ccca180881867f34a9f36 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 25 Jan 2023 19:42:26 +0300 Subject: add pix2pix credits --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index a5611671..c6bd6f27 100644 --- a/README.md +++ b/README.md @@ -155,7 +155,8 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - Idea for Composable Diffusion - 
https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
 - xformers - https://github.com/facebookresearch/xformers
 - DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
+- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
+- Instruct pix2pix - Tim Brooks (*), Aleksander Holynski (*), Alexei A. Efros (no *) - https://github.com/timothybrooks/instruct-pix2pix
 - Security advice - RyotaK
 - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
-- Sampling in float32 precision from a float16 UNet - marunine for the idea, Birch-san for the example Diffusers implementation (https://github.com/Birch-san/diffusers-play/tree/92feee6)
 - (You)
--
cgit v1.2.3


From e179b6098ac1b1ce9645fef5bd9fd0bc9b918f30 Mon Sep 17 00:00:00 2001
From: "Alex \"mcmonkey\" Goodwin"
Date: Wed, 25 Jan 2023 08:48:40 -0800
Subject: allow symlinks in the textual inversion embeddings folder

---
 modules/textual_inversion/textual_inversion.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 4e90f690..6cf00e65 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -194,7 +194,7 @@ class EmbeddingDatabase:
         if not os.path.isdir(embdir.path):
             return

-        for root, dirs, fns in os.walk(embdir.path):
+        for root, dirs, fns in os.walk(embdir.path, followlinks=True):
             for fn in fns:
                 try:
                     fullfn = os.path.join(root, fn)
--
cgit v1.2.3


From 789d47f832a5c921dbbdd0a657dff9bca7f78d94 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Wed, 25 Jan 2023 19:55:31 +0300
Subject: make clicking extra networks button one more time close the extra networks UI

---
 modules/ui_extra_networks.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 8b4f97f8..c6ff889a 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -117,8 +117,13 @@ def create_ui(container, button, tabname):
     ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
     ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False)

-    button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container])
-    button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container])
+    def toggle_visibility(is_visible):
+        is_visible = not is_visible
+        return is_visible, gr.update(visible=is_visible)
+
+    state_visible = gr.State(value=False)
+    button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])
+    button_close.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container])

     def refresh():
         res = []
--
cgit v1.2.3


From e425b9812b067073eb6edfafac689735f5391b45 Mon Sep 17 00:00:00 2001
From: Spaceginner
Date: Wed, 25 Jan 2023 22:07:48 +0500
Subject: Added Python version check

---
 launch.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/launch.py b/launch.py
index 9d6f4a8c..86b4a32b 100644
--- a/launch.py
+++ b/launch.py
@@ -17,6 +17,17 @@ stored_commit_hash = None
 skip_install = False


+def check_python_version():
+    version = 
sys.version_info + version_range = None + if os.name == "nt": + version_range = range(7, 11) + else: + version_range = range(7, 12) + + assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/" + + def commit_hash(): global stored_commit_hash @@ -321,5 +332,6 @@ def start(): if __name__ == "__main__": + check_python_version() prepare_environment() start() -- cgit v1.2.3 From 15e89ef0f6f22f823c19592a401b9e4ee477258c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 25 Jan 2023 20:11:01 +0300 Subject: fix for unet hijack breaking the train tab --- modules/sd_hijack_unet.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index 88c94e54..a6ee577c 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -36,8 +36,11 @@ th = TorchHijackForUnet() # Below are monkey patches to enable upcasting a float16 UNet for float32 sampling def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): - for y in cond.keys(): - cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + + if isinstance(cond, dict): + for y in cond.keys(): + cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + with devices.autocast(): return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() -- cgit v1.2.3 From 57096823fadbc18b33d9b89d2d3a02d5ebba29f4 Mon Sep 17 00:00:00 2001 From: Spaceginner Date: Wed, 25 Jan 2023 22:33:35 +0500 Subject: Remove a stacktrace from an assertion to not scare people --- launch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/launch.py b/launch.py index 86b4a32b..cf747e72 100644 --- a/launch.py +++ b/launch.py @@ -25,7 +25,10 @@ def check_python_version(): else: version_range = range(7, 12) - assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/" + try: + assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/" + except AssertionError as e: + print(e) def commit_hash(): -- cgit v1.2.3 From 2de99d62dd80123bf2d7dcbb2c4970fad5d92d42 Mon Sep 17 00:00:00 2001 From: Spaceginner Date: Wed, 25 Jan 2023 22:38:28 +0500 Subject: some clarification --- launch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/launch.py b/launch.py index cf747e72..4608bc81 100644 --- a/launch.py +++ b/launch.py @@ -26,9 +26,10 @@ def check_python_version(): version_range = range(7, 12) try: - assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/" + assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. 
You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/. Please, make sure to first delete current version of Python first." except AssertionError as e: print(e) + sys.exit(-1) def commit_hash(): -- cgit v1.2.3 From 0cc5f380d5a21625413554a6a64b97172b36d64a Mon Sep 17 00:00:00 2001 From: Spaceginner Date: Wed, 25 Jan 2023 22:41:51 +0500 Subject: even more clarifications(?) i have no idea what commit message should be --- launch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/launch.py b/launch.py index 4608bc81..801e0371 100644 --- a/launch.py +++ b/launch.py @@ -26,7 +26,7 @@ def check_python_version(): version_range = range(7, 12) try: - assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/. Please, make sure to first delete current version of Python first." + assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/. Please, make sure to first delete current version of Python first and delete `venv` folder inside of WebUI's folder, too." except AssertionError as e: print(e) sys.exit(-1) -- cgit v1.2.3 From f5d73b6a6646b51027c8e6f6c6154f21b58d6af2 Mon Sep 17 00:00:00 2001 From: Spaceginner Date: Wed, 25 Jan 2023 22:56:09 +0500 Subject: Fixed typo --- launch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/launch.py b/launch.py index 801e0371..e39c68e7 100644 --- a/launch.py +++ b/launch.py @@ -21,9 +21,9 @@ def check_python_version(): version = sys.version_info version_range = None if os.name == "nt": - version_range = range(7, 11) + version_range = range(7 + 1, 10 + 1) else: - version_range = range(7, 12) + version_range = range(7 + 1, 11 + 1) try: assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/. Please, make sure to first delete current version of Python first and delete `venv` folder inside of WebUI's folder, too." 
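A note on the range arithmetic above: `range(a, b)` excludes `b`, so the original `range(7, 11)` quietly accepted Python 3.7, which is what this "Fixed typo" commit corrects by spelling the bounds out as `7 + 1` and `10 + 1`. A quick standalone check (illustrative snippet, not part of any patch):

# range() excludes its upper bound, so the rewritten bounds read as
# "minor versions 8 through 10" on Windows and "8 through 11" elsewhere.
assert list(range(7, 11)) == [7, 8, 9, 10]            # old bound: 3.7 slipped through
assert list(range(7 + 1, 10 + 1)) == [8, 9, 10]       # corrected Windows range
assert list(range(7 + 1, 11 + 1)) == [8, 9, 10, 11]   # corrected non-Windows range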
-- cgit v1.2.3 From e0df864b8c1f99d7d65d56c6ac8a1e5e314dddba Mon Sep 17 00:00:00 2001 From: brkirch Date: Wed, 25 Jan 2023 13:19:06 -0500 Subject: Update arguments to use --upcast-sampling --- webui-macos-env.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui-macos-env.sh b/webui-macos-env.sh index 95ca9c55..fa187dd1 100644 --- a/webui-macos-env.sh +++ b/webui-macos-env.sh @@ -10,7 +10,7 @@ then fi export install_dir="$HOME" -export COMMANDLINE_ARGS="--skip-torch-cuda-test --no-half --use-cpu interrogate" +export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --use-cpu interrogate" export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1" export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git" export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71" -- cgit v1.2.3 From d1d6ce29831d1b067801c3206f314258de88f683 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 25 Jan 2023 23:25:25 +0300 Subject: add edit_image_conditioning from my earlier edits in case there's an attempt to inegrate pix2pix properly this allows to use pix2pix model in img2img though it won't work well this way --- modules/processing.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/modules/processing.py b/modules/processing.py index 9e5a2f38..cb41288a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -185,7 +185,12 @@ class StableDiffusionProcessing: conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1. return conditioning - def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None): + def edit_image_conditioning(self, source_image): + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) + + return conditioning_image + + def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None): self.is_using_inpainting_conditioning = True # Handle the different mask inputs @@ -228,6 +233,9 @@ class StableDiffusionProcessing: if isinstance(self.sd_model, LatentDepth2ImageDiffusion): return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image) + if self.sd_model.cond_stage_key == "edit": + return self.edit_image_conditioning(source_image) + if self.sampler.conditioning_key in {'hybrid', 'concat'}: return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask) -- cgit v1.2.3 From 6cff4401824299a983c8e13424018efc347b4a2b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 25 Jan 2023 23:25:40 +0300 Subject: fix prompt editing break after first batch in img2img --- modules/sd_samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 6261d1f7..a7910b56 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -454,7 +454,7 @@ class KDiffusionSampler: def initialize(self, p): self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap.step = 0 + self.model_wrap_cfg.step = 0 self.eta = p.eta or opts.eta_ancestral k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) -- cgit v1.2.3 From d82d471bf7797fe09dbc6f3a6ac1c2a76142c8f1 Mon Sep 17 00:00:00 2001 From: Vladimir Repin 
<32306715+mezotaken@users.noreply.github.com> Date: Thu, 26 Jan 2023 02:09:14 +0300 Subject: Ask user to clarify conditions --- .github/ISSUE_TEMPLATE/bug_report.yml | 29 +++++++++++++++++++++++------ 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index ed372f22..7d435297 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -37,20 +37,20 @@ body: id: what-should attributes: label: What should have happened? - description: tell what you think the normal behavior should be + description: Tell what you think the normal behavior should be validations: required: true - type: input id: commit attributes: label: Commit where the problem happens - description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit hash** shown in the cmd/terminal when you launch the UI) + description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.) validations: required: true - type: dropdown id: platforms attributes: - label: What platforms do you use to access UI ? + label: What platforms do you use to access the UI ? multiple: true options: - Windows @@ -74,10 +74,27 @@ body: id: cmdargs attributes: label: Command Line Arguments - description: Are you using any launching parameters/command line arguments (modified webui-user.py) ? If yes, please write them below + description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh) ? If yes, please write them below. Write "No" otherwise. render: Shell + validations: + required: true + - type: textarea + id: extensions + attributes: + label: List of extensions + description: Are you using any extensions other than built-ins? If yes, provide a list, you can copy it at "Extensions" tab. Write "No" otherwise. + validations: + required: true + - type: textarea + id: logs + attributes: + label: Console logs + description: Please provide **full** cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service. + render: Shell + validations: + required: true - type: textarea id: misc attributes: - label: Additional information, context and logs - description: Please provide us with any relevant additional info, context or log output. + label: Additional information + description: Please provide us with any relevant additional info or context. -- cgit v1.2.3 From 10421f93c3f7f7ce88cb40391b46d4e6664eff74 Mon Sep 17 00:00:00 2001 From: brkirch Date: Thu, 26 Jan 2023 00:34:38 -0500 Subject: Fix full previews, --no-half-vae --- modules/processing.py | 8 ++++---- modules/sd_hijack_utils.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/processing.py b/modules/processing.py index cb41288a..92894d67 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -172,7 +172,7 @@ class StableDiffusionProcessing: midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = repeat(midas_in, "1 ... 
-> n ...", n=self.batch_size) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image)) conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image conditioning = torch.nn.functional.interpolate( self.sd_model.depth_model(midas_in), @@ -217,7 +217,7 @@ class StableDiffusionProcessing: ) # Encode the new masked image using first stage of network. - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image)) # Create the concatenated conditioning tensor to be fed to `c_concat` conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:]) @@ -417,7 +417,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see def decode_first_stage(model, x): with devices.autocast(disable=x.dtype == devices.dtype_vae): - x = model.decode_first_stage(x) + x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x) return x @@ -1001,7 +1001,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = torch.from_numpy(batch_images) image = 2. * image - 1. - image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None) + image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None) self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) diff --git a/modules/sd_hijack_utils.py b/modules/sd_hijack_utils.py index f81b169a..f8684475 100644 --- a/modules/sd_hijack_utils.py +++ b/modules/sd_hijack_utils.py @@ -5,7 +5,7 @@ class CondFunc: self = super(CondFunc, cls).__new__(cls) if isinstance(orig_func, str): func_path = orig_func.split('.') - for i in range(len(func_path)-2, -1, -1): + for i in range(len(func_path)-1, -1, -1): try: resolved_obj = importlib.import_module('.'.join(func_path[:i])) break -- cgit v1.2.3 From 1619233a747830887831cfea2f05fe826fce1bed Mon Sep 17 00:00:00 2001 From: Spaceginner Date: Thu, 26 Jan 2023 12:52:44 +0500 Subject: Only Linux will have max 3.11 --- launch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/launch.py b/launch.py index e39c68e7..52f3bd52 100644 --- a/launch.py +++ b/launch.py @@ -20,10 +20,10 @@ skip_install = False def check_python_version(): version = sys.version_info version_range = None - if os.name == "nt": - version_range = range(7 + 1, 10 + 1) - else: + if platform.system() == "Linux": version_range = range(7 + 1, 11 + 1) + else: + version_range = range(7 + 1, 10 + 1) try: assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/. Please, make sure to first delete current version of Python first and delete `venv` folder inside of WebUI's folder, too." 
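The one-character `CondFunc` change in the "Fix full previews, --no-half-vae" commit above deserves a gloss: the constructor resolves a dotted path such as 'modules.processing.decode_first_stage' by trying progressively shorter module prefixes with importlib, then walking the remaining components as attributes. Starting the loop at `len(func_path) - 1` instead of `len(func_path) - 2` means the longest useful prefix (everything up to the final attribute) is now tried first, so the hijack works even when the parent package has not already imported the submodule. A simplified, self-contained sketch of the same resolution strategy, not the module's actual code:

import importlib

def resolve(path):
    parts = path.split('.')
    # Try the longest module prefix first: for "xml.sax.parse" this attempts
    # import_module("xml.sax") before falling back to import_module("xml").
    for i in range(len(parts) - 1, -1, -1):
        try:
            obj = importlib.import_module('.'.join(parts[:i]))
            break
        except ImportError:
            pass
    # Walk whatever is left ("parse") as attributes of the imported module.
    for name in parts[i:]:
        obj = getattr(obj, name)
    return obj

print(resolve('xml.sax.parse'))
# With the loop starting at len(parts) - 2, only "xml" would be imported, and
# getattr(xml, 'sax') raises AttributeError in a fresh interpreter.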
-- cgit v1.2.3 From f4ec411f2c9d6bc6817a2eca8a2c00f255ffb386 Mon Sep 17 00:00:00 2001 From: "ULTRANOX\\Chris" Date: Thu, 26 Jan 2023 03:45:16 -0500 Subject: Allow checkpoint merger to merge pix2pix models in the same way that it currently supports inpainting models. --- modules/extras.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index 36123aa5..67ffdee3 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -132,6 +132,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None result_is_inpainting_model = False + result_is_pix2pix_model = False if theta_func2: shared.state.textinfo = f"Loading B" @@ -186,13 +187,17 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ if a.shape[1] == 4 and b.shape[1] == 9: raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") - assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" - - theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) - result_is_inpainting_model = True + if a.shape[1] == 8 and b.shape[1] == 4:#If we have an InstructPix2Pix model... + print("Detected possible merge of instruct model with non-instruct model.") + theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch. + result_is_pix2pix_model = True + else: + assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" + theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) + result_is_inpainting_model = True else: theta_0[key] = theta_func2(a, b, multiplier) - + theta_0[key] = to_half(theta_0[key], save_as_half) shared.state.sampling_step += 1 @@ -226,6 +231,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ filename = filename_generator() if custom_name == '' else custom_name filename += ".inpainting" if result_is_inpainting_model else "" + filename += ".pix2pix" if result_is_pix2pix_model else "" filename += "." + checkpoint_format output_modelname = os.path.join(ckpt_dir, filename) -- cgit v1.2.3 From f90798c6b6cc48e514acb08ce02bdb5874bf74d8 Mon Sep 17 00:00:00 2001 From: "ULTRANOX\\Chris" Date: Thu, 26 Jan 2023 04:38:04 -0500 Subject: Added error check for the rare case a user merges a pix2pix model with a normal model using weighted sum. Also removed bad print message that interfered with merging progress bar. --- modules/extras.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/extras.py b/modules/extras.py index 67ffdee3..badd13c7 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -186,9 +186,10 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]: if a.shape[1] == 4 and b.shape[1] == 9: raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") + if a.shape[1] == 4 and b.shape[1] == 8: + raise RuntimeError("When merging pix2pix model with a normal one, A must be the pix2pix model.") if a.shape[1] == 8 and b.shape[1] == 4:#If we have an InstructPix2Pix model... 
- print("Detected possible merge of instruct model with non-instruct model.") theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch. result_is_pix2pix_model = True else: -- cgit v1.2.3 From 9e72dc743480c8b1ca6aeb8ced3af03f3e3243a3 Mon Sep 17 00:00:00 2001 From: "ULTRANOX\\Chris" Date: Thu, 26 Jan 2023 06:05:40 -0500 Subject: Changed all references to "pix2pix" to the more precise name "instruct pix2pix". Also changed extension to instrpix2pix at least for now. --- modules/extras.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index badd13c7..2bf0d17e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -132,7 +132,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None result_is_inpainting_model = False - result_is_pix2pix_model = False + result_is_instruct_pix2pix_model = False if theta_func2: shared.state.textinfo = f"Loading B" @@ -187,11 +187,11 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ if a.shape[1] == 4 and b.shape[1] == 9: raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.") if a.shape[1] == 4 and b.shape[1] == 8: - raise RuntimeError("When merging pix2pix model with a normal one, A must be the pix2pix model.") + raise RuntimeError("When merging instruct-pix2pix model with a normal one, A must be the instruct-pix2pix model.") - if a.shape[1] == 8 and b.shape[1] == 4:#If we have an InstructPix2Pix model... + if a.shape[1] == 8 and b.shape[1] == 4:#If we have an Instruct-Pix2Pix model... theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common. Otherwise we get an error due to dimension mismatch. - result_is_pix2pix_model = True + result_is_instruct_pix2pix_model = True else: assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}" theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier) @@ -232,7 +232,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ filename = filename_generator() if custom_name == '' else custom_name filename += ".inpainting" if result_is_inpainting_model else "" - filename += ".pix2pix" if result_is_pix2pix_model else "" + filename += ".instrpix2pix" if result_is_instruct_pix2pix_model else "" filename += "." + checkpoint_format output_modelname = os.path.join(ckpt_dir, filename) -- cgit v1.2.3 From cdc2fa209a3efdc71a90643a5e7a1df49869cd5f Mon Sep 17 00:00:00 2001 From: "ULTRANOX\\Chris" Date: Thu, 26 Jan 2023 11:27:07 -0500 Subject: Changed filename addition from "instrpix2pix" to the more readable ".instruct-pix2pix" for newly generated instruct pix2pix models. 
--- modules/extras.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/extras.py b/modules/extras.py index 2bf0d17e..466ecc15 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -232,7 +232,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ filename = filename_generator() if custom_name == '' else custom_name filename += ".inpainting" if result_is_inpainting_model else "" - filename += ".instrpix2pix" if result_is_instruct_pix2pix_model else "" + filename += ".instruct-pix2pix" if result_is_instruct_pix2pix_model else "" filename += "." + checkpoint_format output_modelname = os.path.join(ckpt_dir, filename) -- cgit v1.2.3 From 7a14c8ab45da8a681792a6331d48a88dd684a0a9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 26 Jan 2023 23:29:27 +0300 Subject: add an option to enable sections from extras tab in txt2img/img2img fix some style inconsistenices --- modules/processing.py | 7 +++++- modules/scripts.py | 32 ++++++++++++++++++++++---- modules/scripts_auto_postprocessing.py | 42 ++++++++++++++++++++++++++++++++++ modules/scripts_postprocessing.py | 11 ++++++--- modules/shared.py | 15 ++++-------- modules/shared_items.py | 10 ++++++++ modules/ui_components.py | 8 +++++++ scripts/postprocessing_upscale.py | 25 ++++++++++++++++++++ style.css | 6 +---- 9 files changed, 133 insertions(+), 23 deletions(-) create mode 100644 modules/scripts_auto_postprocessing.py create mode 100644 modules/shared_items.py diff --git a/modules/processing.py b/modules/processing.py index 92894d67..262806a1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -658,6 +658,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: image = Image.fromarray(x_sample) + if p.scripts is not None: + pp = scripts.PostprocessImageArgs(image) + p.scripts.postprocess_image(p, pp) + image = pp.image + if p.color_corrections is not None and i < len(p.color_corrections): if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction: image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images) diff --git a/modules/scripts.py b/modules/scripts.py index 03907a63..6e9dc0c0 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -6,12 +6,16 @@ from collections import namedtuple import gradio as gr -from modules.processing import StableDiffusionProcessing from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing AlwaysVisible = object() +class PostprocessImageArgs: + def __init__(self, image): + self.image = image + + class Script: filename = None args_from = None @@ -65,7 +69,7 @@ class Script: args contains all values returned by components from ui() """ - raise NotImplementedError() + pass def process(self, p, *args): """ @@ -100,6 +104,13 @@ class Script: pass + def postprocess_image(self, p, pp: PostprocessImageArgs, *args): + """ + Called for every image after it has been 
generated. + """ + + pass + def postprocess(self, p, processed, *args): """ This function is called after processing ends for AlwaysVisible scripts. @@ -247,11 +258,15 @@ class ScriptRunner: self.infotext_fields = [] def initialize_scripts(self, is_img2img): + from modules import scripts_auto_postprocessing + self.scripts.clear() self.alwayson_scripts.clear() self.selectable_scripts.clear() - for script_class, path, basedir, script_module in scripts_data: + auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data() + + for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data: script = script_class() script.filename = path script.is_txt2img = not is_img2img @@ -332,7 +347,7 @@ class ScriptRunner: return inputs - def run(self, p: StableDiffusionProcessing, *args): + def run(self, p, *args): script_index = args[0] if script_index == 0: @@ -386,6 +401,15 @@ class ScriptRunner: print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + def postprocess_image(self, p, pp: PostprocessImageArgs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.postprocess_image(p, pp, *script_args) + except Exception: + print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + def before_component(self, component, **kwargs): for script in self.scripts: try: diff --git a/modules/scripts_auto_postprocessing.py b/modules/scripts_auto_postprocessing.py new file mode 100644 index 00000000..30d6d658 --- /dev/null +++ b/modules/scripts_auto_postprocessing.py @@ -0,0 +1,42 @@ +from modules import scripts, scripts_postprocessing, shared + + +class ScriptPostprocessingForMainUI(scripts.Script): + def __init__(self, script_postproc): + self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc + self.postprocessing_controls = None + + def title(self): + return self.script.name + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_img2img): + self.postprocessing_controls = self.script.ui() + return self.postprocessing_controls.values() + + def postprocess_image(self, p, script_pp, *args): + args_dict = {k: v for k, v in zip(self.postprocessing_controls, args)} + + pp = scripts_postprocessing.PostprocessedImage(script_pp.image) + pp.info = {} + self.script.process(pp, **args_dict) + p.extra_generation_params.update(pp.info) + script_pp.image = pp.image + + +def create_auto_preprocessing_script_data(): + from modules import scripts + + res = [] + + for name in shared.opts.postprocessing_enable_in_main_ui: + script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None) + if script is None: + continue + + constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class()) + res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module)) + + return res diff --git a/modules/scripts_postprocessing.py b/modules/scripts_postprocessing.py index 25de02d0..ce0ebb61 100644 --- a/modules/scripts_postprocessing.py +++ b/modules/scripts_postprocessing.py @@ -46,6 +46,8 @@ class ScriptPostprocessing: pass + + def wrap_call(func, filename, funcname, *args, default=None, **kwargs): try: res = func(*args, **kwargs) @@ -68,6 +70,9 @@ class ScriptPostprocessingRunner: script: ScriptPostprocessing = script_class() 
script.filename = path + if script.name == "Simple Upscale": + continue + self.scripts.append(script) def create_script_ui(self, script, inputs): @@ -87,12 +92,11 @@ class ScriptPostprocessingRunner: import modules.scripts self.initialize_scripts(modules.scripts.postprocessing_scripts_data) - scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")] + scripts_order = shared.opts.postprocessing_operation_order def script_score(name): - name = name.lower() for i, possible_match in enumerate(scripts_order): - if possible_match in name: + if possible_match == name: return i return len(self.scripts) @@ -145,3 +149,4 @@ class ScriptPostprocessingRunner: def image_changed(self): for script in self.scripts_in_preferred_order(): script.image_changed() + diff --git a/modules/shared.py b/modules/shared.py index 6a0b96cb..cdeed55d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,8 +13,8 @@ import modules.interrogate import modules.memmon import modules.styles import modules.devices as devices -from modules import localization, sd_vae, extensions, script_loading, errors, ui_components -from modules.paths import models_path, script_path, sd_path +from modules import localization, sd_vae, extensions, script_loading, errors, ui_components, shared_items +from modules.paths import models_path, script_path demo = None @@ -264,12 +264,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate") face_restorers = [] - -def realesrgan_models_names(): - import modules.realesrgan_model - return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] - - class OptionInfo: def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None): self.default = default @@ -360,7 +354,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo options_templates.update(options_section(('upscaling', "Upscaling"), { "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), - "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), + "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. 
(Requires restart)", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), })) @@ -483,7 +477,8 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" })) options_templates.update(options_section(('postprocessing', "Postprocessing"), { - 'postprocessing_scipts_order': OptionInfo("upscale, gfpgan, codeformer", "Postprocessing operation order"), + 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), + 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), })) diff --git a/modules/shared_items.py b/modules/shared_items.py new file mode 100644 index 00000000..b5d480c9 --- /dev/null +++ b/modules/shared_items.py @@ -0,0 +1,10 @@ + + +def realesrgan_models_names(): + import modules.realesrgan_model + return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] + +def postprocessing_scripts(): + import modules.scripts + + return modules.scripts.scripts_postproc.scripts \ No newline at end of file diff --git a/modules/ui_components.py b/modules/ui_components.py index 9aec3097..284ca0cf 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -48,3 +48,11 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent): def get_block_name(self): return "colorpicker" + +class DropdownMulti(gr.Dropdown): + """Same as gr.Dropdown but always multiselect""" + def __init__(self, **kwargs): + super().__init__(multiselect=True, **kwargs) + + def get_block_name(self): + return "dropdown" diff --git a/scripts/postprocessing_upscale.py b/scripts/postprocessing_upscale.py index 095d29b2..8842bd91 100644 --- a/scripts/postprocessing_upscale.py +++ b/scripts/postprocessing_upscale.py @@ -104,3 +104,28 @@ class ScriptPostprocessingUpscale(scripts_postprocessing.ScriptPostprocessing): def image_changed(self): upscale_cache.clear() + + +class ScriptPostprocessingUpscaleSimple(ScriptPostprocessingUpscale): + name = "Simple Upscale" + order = 900 + + def ui(self): + with FormRow(): + upscaler_name = gr.Dropdown(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name) + upscale_by = gr.Slider(minimum=0.05, maximum=8.0, step=0.05, label="Upscale by", value=2) + + return { + "upscale_by": upscale_by, + "upscaler_name": upscaler_name, + } + + def process(self, pp: scripts_postprocessing.PostprocessedImage, upscale_by=2.0, upscaler_name=None): + if upscaler_name is None or upscaler_name == "None": + return + + upscaler1 = next(iter([x for x in shared.sd_upscalers if x.name == upscaler_name]), None) + assert upscaler1, f'could not find upscaler named {upscaler_name}' + + pp.image = self.upscale(pp.image, pp.info, upscaler1, 0, upscale_by, 0, 0, False) + pp.info[f"Postprocess upscaler"] = upscaler1.name diff --git a/style.css b/style.css index ec046f78..dd914104 100644 --- a/style.css +++ b/style.css @@ -164,7 +164,7 @@ min-height: 3.2em; } -#txt2img_styles ul, #img2img_styles ul{ +ul.list-none{ max-height: 35em; 
z-index: 2000; } @@ -714,9 +714,6 @@ footer { white-space: nowrap; min-width: auto; } -#txt2img_hires_fix{ - margin-left: -0.8em; -} #img2img_copy_to_img2img, #img2img_copy_to_sketch, #img2img_copy_to_inpaint, #img2img_copy_to_inpaint_sketch{ margin-left: 0em; @@ -744,7 +741,6 @@ footer { .dark .gr-compact{ background-color: rgb(31 41 55 / var(--tw-bg-opacity)); - margin-left: 0.8em; } .gr-compact{ -- cgit v1.2.3 From d2ac95fa7b2a8d0bcc5361ee16dba9cbb81ff8b2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 27 Jan 2023 11:28:12 +0300 Subject: remove the need to place configs near models --- configs/instruct-pix2pix.yaml | 99 +++++++++++++++ configs/v1-inpainting-inference.yaml | 70 +++++++++++ modules/api/api.py | 5 +- modules/devices.py | 12 +- modules/sd_hijack_inpainting.py | 9 -- modules/sd_models.py | 228 +++++++++++++++++------------------ modules/sd_models_config.py | 65 ++++++++++ modules/shared.py | 7 +- modules/shared_items.py | 15 ++- modules/timer.py | 35 ++++++ v2-inference-v.yaml | 68 ----------- 11 files changed, 411 insertions(+), 202 deletions(-) create mode 100644 configs/instruct-pix2pix.yaml create mode 100644 configs/v1-inpainting-inference.yaml create mode 100644 modules/sd_models_config.py create mode 100644 modules/timer.py delete mode 100644 v2-inference-v.yaml diff --git a/configs/instruct-pix2pix.yaml b/configs/instruct-pix2pix.yaml new file mode 100644 index 00000000..437ddcef --- /dev/null +++ b/configs/instruct-pix2pix.yaml @@ -0,0 +1,99 @@ +# File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion). +# See more details in LICENSE. + +model: + base_learning_rate: 1.0e-04 + target: modules.models.diffusion.ddpm_edit.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: edited + cond_stage_key: edit + # image_size: 64 + # image_size: 32 + image_size: 16 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: hybrid + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: true + load_ema: true + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 0 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. 
] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + +data: + target: main.DataModuleFromConfig + params: + batch_size: 128 + num_workers: 1 + wrap: false + validation: + target: edit_dataset.EditDataset + params: + path: data/clip-filtered-dataset + cache_dir: data/ + cache_name: data_10k + split: val + min_text_sim: 0.2 + min_image_sim: 0.75 + min_direction_sim: 0.2 + max_samples_per_prompt: 1 + min_resize_res: 512 + max_resize_res: 512 + crop_res: 512 + output_as_edit: False + real_input: True diff --git a/configs/v1-inpainting-inference.yaml b/configs/v1-inpainting-inference.yaml new file mode 100644 index 00000000..f9eec37d --- /dev/null +++ b/configs/v1-inpainting-inference.yaml @@ -0,0 +1,70 @@ +model: + base_learning_rate: 7.5e-05 + target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: hybrid # important + monitor: val/loss_simple_ema + scale_factor: 0.18215 + finetune_keys: null + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. 
] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 9 # 4 data + 4 downscaled image + 1 mask + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPEmbedder diff --git a/modules/api/api.py b/modules/api/api.py index 25c65e57..eb7b1da5 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -18,7 +18,8 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image -from modules.sd_models import checkpoints_list, find_checkpoint_config +from modules.sd_models import checkpoints_list +from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices from typing import List @@ -387,7 +388,7 @@ class Api: ] def get_sd_models(self): - return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] + return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()] def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] diff --git a/modules/devices.py b/modules/devices.py index 6b36622c..2d5f797a 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -34,14 +34,18 @@ def get_cuda_device_string(): return "cuda" -def get_optimal_device(): +def get_optimal_device_name(): if torch.cuda.is_available(): - return torch.device(get_cuda_device_string()) + return get_cuda_device_string() if has_mps(): - return torch.device("mps") + return "mps" + + return "cpu" - return cpu + +def get_optimal_device(): + return torch.device(get_optimal_device_name()) def get_device_for(task): diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index 31d2c898..478cd499 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -96,15 +96,6 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F return x_prev, pred_x0, e_t -def should_hijack_inpainting(checkpoint_info): - from modules import sd_models - - ckpt_basename = os.path.basename(checkpoint_info.filename).lower() - cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower() - - return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename - - def do_inpainting_hijack(): # p_sample_plms is needed because PLMS can't work with dicts as conditionings diff --git a/modules/sd_models.py b/modules/sd_models.py index 
7072eb2e..fa208728 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -2,8 +2,6 @@ import collections import os.path import sys import gc -import time -from collections import namedtuple import torch import re import safetensors.torch @@ -14,10 +12,10 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes +from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config from modules.paths import models_path -from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting -from modules.sd_hijack_ip2p import should_hijack_ip2p +from modules.sd_hijack_inpainting import do_inpainting_hijack +from modules.timer import Timer model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) @@ -99,17 +97,6 @@ def checkpoint_tiles(): return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) -def find_checkpoint_config(info): - if info is None: - return shared.cmd_opts.config - - config = os.path.splitext(info.filename)[0] + ".yaml" - if os.path.exists(config): - return config - - return shared.cmd_opts.config - - def list_models(): checkpoints_list.clear() checkpoint_alisases.clear() @@ -215,9 +202,7 @@ def get_state_dict_from_checkpoint(pl_sd): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): _, extension = os.path.splitext(checkpoint_file) if extension.lower() == ".safetensors": - device = map_location or shared.weight_load_location - if device is None: - device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu" + device = map_location or shared.weight_load_location or devices.get_optimal_device_name() pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) else: pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) @@ -229,60 +214,74 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None return sd -def load_model_weights(model, checkpoint_info: CheckpointInfo): +def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): + sd_model_hash = checkpoint_info.calculate_shorthash() + timer.record("calculate hash") + + if checkpoint_info in checkpoints_loaded: + # use checkpoint cache + print(f"Loading weights [{sd_model_hash}] from cache") + return checkpoints_loaded[checkpoint_info] + + print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") + res = read_state_dict(checkpoint_info.filename) + timer.record("load weights from disk") + + return res + + +def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): title = checkpoint_info.title sd_model_hash = checkpoint_info.calculate_shorthash() + timer.record("calculate hash") + if checkpoint_info.title != title: shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title - cache_enabled = shared.opts.sd_checkpoint_cache > 0 + if state_dict is None: + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - if cache_enabled and checkpoint_info in checkpoints_loaded: - # use checkpoint cache - print(f"Loading weights [{sd_model_hash}] from cache") - model.load_state_dict(checkpoints_loaded[checkpoint_info]) - else: - # load from file - print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") + model.load_state_dict(state_dict, 
strict=False) + del state_dict + timer.record("apply weights to model") - sd = read_state_dict(checkpoint_info.filename) - model.load_state_dict(sd, strict=False) - del sd - - if cache_enabled: - # cache newly loaded model - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + if shared.opts.sd_checkpoint_cache > 0: + # cache newly loaded model + checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + + if shared.cmd_opts.opt_channelslast: + model.to(memory_format=torch.channels_last) + timer.record("apply channels_last") - if shared.cmd_opts.opt_channelslast: - model.to(memory_format=torch.channels_last) + if not shared.cmd_opts.no_half: + vae = model.first_stage_model + depth_model = getattr(model, 'depth_model', None) - if not shared.cmd_opts.no_half: - vae = model.first_stage_model - depth_model = getattr(model, 'depth_model', None) + # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 + if shared.cmd_opts.no_half_vae: + model.first_stage_model = None + # with --upcast-sampling, don't convert the depth model weights to float16 + if shared.cmd_opts.upcast_sampling and depth_model: + model.depth_model = None - # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 - if shared.cmd_opts.no_half_vae: - model.first_stage_model = None - # with --upcast-sampling, don't convert the depth model weights to float16 - if shared.cmd_opts.upcast_sampling and depth_model: - model.depth_model = None + model.half() + model.first_stage_model = vae + if depth_model: + model.depth_model = depth_model - model.half() - model.first_stage_model = vae - if depth_model: - model.depth_model = depth_model + timer.record("apply half()") - devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 - devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - devices.dtype_unet = model.model.diffusion_model.dtype - devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 + devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 + devices.dtype_unet = model.model.diffusion_model.dtype + devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 - model.first_stage_model.to(devices.dtype_vae) + model.first_stage_model.to(devices.dtype_vae) + timer.record("apply dtype to VAE") # clean up cache if limit is reached - if cache_enabled: - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model - checkpoints_loaded.popitem(last=False) # LRU + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_info.filename @@ -295,6 +294,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo): sd_vae.clear_loaded_vae() vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) sd_vae.load_vae(model, vae_file, vae_source) + timer.record("load VAE") def enable_midas_autodownload(): @@ -340,24 +340,20 @@ def enable_midas_autodownload(): midas.api.load_model = load_model_wrapper -class Timer: - def __init__(self): - self.start = time.time() +def 
repair_config(sd_config): - def elapsed(self): - end = time.time() - res = end - self.start - self.start = end - return res + if not hasattr(sd_config.model.params, "use_ema"): + sd_config.model.params.use_ema = False + if shared.cmd_opts.no_half: + sd_config.model.params.unet_config.params.use_fp16 = False + elif shared.cmd_opts.upcast_sampling: + sd_config.model.params.unet_config.params.use_fp16 = True -def load_model(checkpoint_info=None): + +def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() - checkpoint_config = find_checkpoint_config(checkpoint_info) - - if checkpoint_config != shared.cmd_opts.config: - print(f"Loading config from: {checkpoint_config}") if shared.sd_model: sd_hijack.model_hijack.undo_hijack(shared.sd_model) @@ -365,38 +361,27 @@ def load_model(checkpoint_info=None): gc.collect() devices.torch_gc() - sd_config = OmegaConf.load(checkpoint_config) - - if should_hijack_inpainting(checkpoint_info): - # Hardcoded config for now... - sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" - sd_config.model.params.conditioning_key = "hybrid" - sd_config.model.params.unet_config.params.in_channels = 9 - sd_config.model.params.finetune_keys = None - - if should_hijack_ip2p(checkpoint_info): - sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion" - sd_config.model.params.conditioning_key = "hybrid" - sd_config.model.params.first_stage_key = "edited" - sd_config.model.params.cond_stage_key = "edit" - sd_config.model.params.image_size = 16 - sd_config.model.params.unet_config.params.in_channels = 8 - sd_config.model.params.unet_config.params.out_channels = 4 + do_inpainting_hijack() - if not hasattr(sd_config.model.params, "use_ema"): - sd_config.model.params.use_ema = False + timer = Timer() - do_inpainting_hijack() + if already_loaded_state_dict is not None: + state_dict = already_loaded_state_dict + else: + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - if shared.cmd_opts.no_half: - sd_config.model.params.unet_config.params.use_fp16 = False - elif shared.cmd_opts.upcast_sampling: - sd_config.model.params.unet_config.params.use_fp16 = True + checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - timer = Timer() + timer.record("find config") - sd_model = None + sd_config = OmegaConf.load(checkpoint_config) + repair_config(sd_config) + + timer.record("load config") + + print(f"Creating model from config: {checkpoint_config}") + sd_model = None try: with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) @@ -407,29 +392,35 @@ def load_model(checkpoint_info=None): print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) sd_model = instantiate_from_config(sd_config.model) - elapsed_create = timer.elapsed() + sd_model.used_config = checkpoint_config - load_model_weights(sd_model, checkpoint_info) + timer.record("create model") - elapsed_load_weights = timer.elapsed() + load_model_weights(sd_model, checkpoint_info, state_dict, timer) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) else: sd_model.to(shared.device) + timer.record("move model to device") + sd_hijack.model_hijack.hijack(sd_model) + timer.record("hijack") + sd_model.eval() shared.sd_model = sd_model 
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model + timer.record("load textual inversion embeddings") + script_callbacks.model_loaded_callback(sd_model) - elapsed_the_rest = timer.elapsed() + timer.record("scripts callbacks") - print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).") + print(f"Model loaded in {timer.summary()}.") return sd_model @@ -440,6 +431,7 @@ def reload_model_weights(sd_model=None, info=None): if not sd_model: sd_model = shared.sd_model + if sd_model is None: # previous model load failed current_checkpoint_info = None else: @@ -447,14 +439,6 @@ def reload_model_weights(sd_model=None, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return - checkpoint_config = find_checkpoint_config(current_checkpoint_info) - - if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info): - del sd_model - checkpoints_loaded.clear() - load_model(checkpoint_info) - return shared.sd_model - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() else: @@ -464,21 +448,35 @@ def reload_model_weights(sd_model=None, info=None): timer = Timer() + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) + + checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) + + timer.record("find config") + + if sd_model is None or checkpoint_config != sd_model.used_config: + del sd_model + checkpoints_loaded.clear() + load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"]) + return shared.sd_model + try: - load_model_weights(sd_model, checkpoint_info) + load_model_weights(sd_model, checkpoint_info, state_dict, timer) except Exception as e: print("Failed to load checkpoint, restoring previous") - load_model_weights(sd_model, current_checkpoint_info) + load_model_weights(sd_model, current_checkpoint_info, None, timer) raise finally: sd_hijack.model_hijack.hijack(sd_model) + timer.record("hijack") + script_callbacks.model_loaded_callback(sd_model) + timer.record("script callbacks") if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) + timer.record("move model to device") - elapsed = timer.elapsed() - - print(f"Weights loaded in {elapsed:.1f}s.") + print(f"Weights loaded in {timer.summary()}.") return sd_model diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py new file mode 100644 index 00000000..ea773a10 --- /dev/null +++ b/modules/sd_models_config.py @@ -0,0 +1,65 @@ +import re +import os + +from modules import shared, paths + +sd_configs_path = shared.sd_configs_path +sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") + + +config_default = shared.sd_default_config +config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") +config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") +config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") +config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") 
+config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") + +re_parametrization_v = re.compile(r'-v\b') + + +def guess_model_config_from_state_dict(sd, filename): + fn = os.path.basename(filename) + + sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) + diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) + roberta_weight = sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) + + if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: + if re.search(re_parametrization_v, fn) or "v2-1_768" in fn: + return config_sd2v + else: + return config_sd2 + + if diffusion_model_input is not None: + if diffusion_model_input.shape[1] == 9: + return config_inpainting + if diffusion_model_input.shape[1] == 8: + return config_instruct_pix2pix + + if roberta_weight is not None: + return config_alt_diffusion + + return config_default + + +def find_checkpoint_config(state_dict, info): + if info is None: + return guess_model_config_from_state_dict(state_dict, "") + + config = find_checkpoint_config_near_filename(info) + if config is not None: + return config + + return guess_model_config_from_state_dict(state_dict, info.filename) + + +def find_checkpoint_config_near_filename(info): + if info is None: + return None + + config = os.path.splitext(info.filename)[0] + ".yaml" + if os.path.exists(config): + return config + + return None + diff --git a/modules/shared.py b/modules/shared.py index cdeed55d..14be993d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,13 +13,14 @@ import modules.interrogate import modules.memmon import modules.styles import modules.devices as devices -from modules import localization, sd_vae, extensions, script_loading, errors, ui_components, shared_items +from modules import localization, extensions, script_loading, errors, ui_components, shared_items from modules.paths import models_path, script_path demo = None -sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml") +sd_configs_path = os.path.join(script_path, "configs") +sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml") sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file @@ -391,7 +392,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), - "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list), + "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}), diff 
--git a/modules/shared_items.py b/modules/shared_items.py index b5d480c9..8b5ec96d 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -4,7 +4,20 @@ def realesrgan_models_names(): import modules.realesrgan_model return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] + def postprocessing_scripts(): import modules.scripts - return modules.scripts.scripts_postproc.scripts \ No newline at end of file + return modules.scripts.scripts_postproc.scripts + + +def sd_vae_items(): + import modules.sd_vae + + return ["Automatic", "None"] + list(modules.sd_vae.vae_dict) + + +def refresh_vae_list(): + import modules.sd_vae + + return modules.sd_vae.refresh_vae_list diff --git a/modules/timer.py b/modules/timer.py new file mode 100644 index 00000000..57a4f17a --- /dev/null +++ b/modules/timer.py @@ -0,0 +1,35 @@ +import time + + +class Timer: + def __init__(self): + self.start = time.time() + self.records = {} + self.total = 0 + + def elapsed(self): + end = time.time() + res = end - self.start + self.start = end + return res + + def record(self, category, extra_time=0): + e = self.elapsed() + if category not in self.records: + self.records[category] = 0 + + self.records[category] += e + extra_time + self.total += e + extra_time + + def summary(self): + res = f"{self.total:.1f}s" + + additions = [x for x in self.records.items() if x[1] >= 0.1] + if not additions: + return res + + res += " (" + res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions]) + res += ")" + + return res diff --git a/v2-inference-v.yaml b/v2-inference-v.yaml deleted file mode 100644 index 513cd635..00000000 --- a/v2-inference-v.yaml +++ /dev/null @@ -1,68 +0,0 @@ -model: - base_learning_rate: 1.0e-4 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - parameterization: "v" - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False # we set this to false because this is an inference only config - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_checkpoint: True - use_fp16: True - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_head_channels: 64 # need to fix for flash-attn - use_spatial_transformer: True - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - #attn_type: "vanilla-xformers" - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: - target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: True - layer: "penultimate" \ No newline at end of file -- cgit v1.2.3 From 6f31d2210c189f8db118e6f95add7ba2a64f0238 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 27 Jan 2023 11:54:19 +0300 Subject: support detecting midas model fix broken api for checkpoint list --- modules/api/models.py | 2 +- modules/sd_models.py | 10 
+++++----- modules/sd_models_config.py | 7 +++++-- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/modules/api/models.py b/modules/api/models.py index 805bd8f7..cba43d3b 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -228,7 +228,7 @@ class SDModelItem(BaseModel): hash: Optional[str] = Field(title="Short hash") sha256: Optional[str] = Field(title="sha256 hash") filename: str = Field(title="Filename") - config: str = Field(title="Config file") + config: Optional[str] = Field(title="Config file") class HypernetworkItem(BaseModel): name: str = Field(title="Name") diff --git a/modules/sd_models.py b/modules/sd_models.py index fa208728..37dad18d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -439,12 +439,12 @@ def reload_model_weights(sd_model=None, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: - lowvram.send_everything_to_cpu() - else: - sd_model.to(devices.cpu) + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + else: + sd_model.to(devices.cpu) - sd_hijack.model_hijack.undo_hijack(sd_model) + sd_hijack.model_hijack.undo_hijack(sd_model) timer = Timer() diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index ea773a10..4d1e92e1 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -10,6 +10,7 @@ sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", config_default = shared.sd_default_config config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") +config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") config_alt_diffusion = os.path.join(sd_configs_path, "alt-diffusion-inference.yaml") @@ -22,7 +23,9 @@ def guess_model_config_from_state_dict(sd, filename): sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None) diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) - roberta_weight = sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) + + if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: + return config_depth_model if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: if re.search(re_parametrization_v, fn) or "v2-1_768" in fn: @@ -36,7 +39,7 @@ def guess_model_config_from_state_dict(sd, filename): if diffusion_model_input.shape[1] == 8: return config_instruct_pix2pix - if roberta_weight is not None: + if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: return config_alt_diffusion return config_default -- cgit v1.2.3 From 9beb794e0b0dc1a0f9e89d8e38bd789a8c608397 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 27 Jan 2023 13:08:00 +0300 Subject: clarify the option to disable NaN check. --- modules/devices.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/devices.py b/modules/devices.py index 2d5f797a..4687944e 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -143,6 +143,8 @@ def test_for_nans(x, where): else: message = "A tensor with all NaNs was produced." 
+ message += " Use --disable-nan-check commandline argument to disable this check." + raise NansException(message) -- cgit v1.2.3 From 9ecf1e827c5966e11495a0c066a127defbba9bcc Mon Sep 17 00:00:00 2001 From: Spaceginner Date: Fri, 27 Jan 2023 17:35:24 +0500 Subject: Made it only a warning --- .gitignore | 1 + launch.py | 39 ++++++++++++++++++++++++++++----------- 2 files changed, 29 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 0b1d17ca..c8be9688 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,4 @@ notification.mp3 /test/stdout.txt /test/stderr.txt /cache.json +no_py_ver_warning diff --git a/launch.py b/launch.py index 52f3bd52..4f5a4bc4 100644 --- a/launch.py +++ b/launch.py @@ -18,18 +18,35 @@ skip_install = False def check_python_version(): - version = sys.version_info - version_range = None - if platform.system() == "Linux": - version_range = range(7 + 1, 11 + 1) - else: - version_range = range(7 + 1, 10 + 1) + if not os.path.isfile("no_py_ver_warning"): + version = sys.version_info + version_range = None + if platform.system() == "Linux": + version_range = range(7 + 1, 11 + 1) + else: + version_range = range(7 + 1, 10 + 1) - try: - assert version.major == 3 and version.minor in version_range, "Unsupported Python version, please use Python 3.10.x instead. You can download latest release as of 25th January (3.10.9) from here: https://www.python.org/downloads/release/python-3109/. Please, make sure to first delete current version of Python first and delete `venv` folder inside of WebUI's folder, too." - except AssertionError as e: - print(e) - sys.exit(-1) + try: + assert version.major == 3 and version.minor in version_range, f""" +=== Warning === +This program was tested only with 3.10 Python, but you have {version.major}.{version.minor} Python. +If you encounter an error with "RuntimeError: Couldn't install torch." message, +or any other error regarding unsuccessful package (library) installation, +please downgrade (or upgrade) to the latest version of 3.10 Python +and delete current Python and "venv" folder in WebUI's directory. + +You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/ + +You will see this warning only once, delete file "no_py_ver_warning" file to show this warning again. 
+=== Warning === + +Press ENTER to continue...\ +""" + except AssertionError as e: + print(e) + with open("no_py_ver_warning", "w"): + pass + input() def commit_hash(): -- cgit v1.2.3 From 5eee2ac39863f9e44591b50d0710dd2615416a13 Mon Sep 17 00:00:00 2001 From: Max Audron Date: Wed, 25 Jan 2023 17:15:42 +0100 Subject: add data-dir flag and set all user data directories based on it --- modules/extensions.py | 2 +- modules/generation_parameters_copypaste.py | 4 ++-- modules/gfpgan_model.py | 5 ++--- modules/hashes.py | 4 +++- modules/interrogate.py | 2 +- modules/paths.py | 10 +++++++++- modules/processing.py | 3 ++- modules/sd_models.py | 6 +++--- modules/sd_vae.py | 5 ++--- modules/shared.py | 11 ++++++----- modules/textual_inversion/preprocess.py | 5 ++--- modules/ui.py | 6 +++--- modules/ui_extensions.py | 2 +- modules/upscaler.py | 5 ++--- 14 files changed, 39 insertions(+), 31 deletions(-) diff --git a/modules/extensions.py b/modules/extensions.py index b522125c..92ee8144 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -7,7 +7,7 @@ import git from modules import paths, shared extensions = [] -extensions_dir = os.path.join(paths.script_path, "extensions") +extensions_dir = os.path.join(paths.data_path, "extensions") extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin") diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 46e12dc6..35f72808 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -6,7 +6,7 @@ import re from pathlib import Path import gradio as gr -from modules.shared import script_path +from modules.paths import data_path, script_path from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image @@ -289,7 +289,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model def connect_paste(button, paste_fields, input_comp, jsfunc=None): def paste_func(prompt): if not prompt and not shared.cmd_opts.hide_ui_dir_config: - filename = os.path.join(script_path, "params.txt") + filename = os.path.join(data_path, "params.txt") if os.path.exists(filename): with open(filename, "r", encoding="utf8") as file: prompt = file.read() diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 1e2dbc32..fbe6215a 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -6,12 +6,11 @@ import facexlib import gfpgan import modules.face_restoration -from modules import shared, devices, modelloader -from modules.paths import models_path +from modules import paths, shared, devices, modelloader model_dir = "GFPGAN" user_path = None -model_path = os.path.join(models_path, model_dir) +model_path = os.path.join(paths.models_path, model_dir) model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" have_gfpgan = False loaded_gfpgan_model = None diff --git a/modules/hashes.py b/modules/hashes.py index b85a7580..819362a3 100644 --- a/modules/hashes.py +++ b/modules/hashes.py @@ -4,8 +4,10 @@ import os.path import filelock +from modules.paths import data_path -cache_filename = "cache.json" + +cache_filename = os.path.join(data_path, "cache.json") cache_data = None diff --git a/modules/interrogate.py b/modules/interrogate.py index c72ff694..cbb80683 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -12,7 +12,7 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode import modules.shared 
as shared -from modules import devices, paths, lowvram, modelloader, errors +from modules import devices, paths, shared, lowvram, modelloader, errors blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' diff --git a/modules/paths.py b/modules/paths.py index 20b3e4d8..08e6f9b9 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -4,7 +4,15 @@ import sys import modules.safe script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) -models_path = os.path.join(script_path, "models") + +# Parse the --data-dir flag first so we can use it as a base for our other argument default values +parser = argparse.ArgumentParser() +parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",) +cmd_opts_pre = parser.parse_known_args()[0] +data_path = cmd_opts_pre.data_dir +models_path = os.path.join(data_path, "models") + +# data_path = cmd_opts_pre.data sys.path.insert(0, script_path) # search for directory of stable diffusion in following places diff --git a/modules/processing.py b/modules/processing.py index 262806a1..5072fc40 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -17,6 +17,7 @@ from modules import devices, prompt_parser, masking, sd_samplers, lowvram, gener from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared +import modules.paths as paths import modules.face_restoration import modules.images as images import modules.styles @@ -584,7 +585,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if not p.disable_extra_networks: extra_networks.activate(p, extra_network_data) - with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: + with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file: processed = Processed(p, [], p.seed, "") file.write(processed.infotext(p, 0)) diff --git a/modules/sd_models.py b/modules/sd_models.py index 37dad18d..b2d48a51 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -12,13 +12,13 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config +from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer model_dir = "Stable-diffusion" -model_path = os.path.abspath(os.path.join(models_path, model_dir)) +model_path = os.path.abspath(os.path.join(paths.models_path, model_dir)) checkpoints_list = {} checkpoint_alisases = {} @@ -307,7 +307,7 @@ def enable_midas_autodownload(): location automatically. """ - midas_path = os.path.join(models_path, 'midas') + midas_path = os.path.join(paths.models_path, 'midas') # stable-diffusion-stability-ai hard-codes the midas model path to # a location that differs from where other scripts using this model look. 
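The heart of this commit is the two-phase argument parsing added to modules/paths.py above: a throwaway parser reads only --data-dir via parse_known_args() before the full command line exists, so every later default can be derived from data_path. A minimal, self-contained sketch of the pattern, using --styles-file (taken from the shared.py hunk later in this patch) as the downstream consumer:

    import argparse
    import os

    # Phase 1: pre-parse only the base-directory flag. parse_known_args()
    # returns (namespace, leftover_args) and does not error out on options
    # this parser has never been told about.
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument("--data-dir", type=str, default=os.getcwd())
    data_path = pre_parser.parse_known_args()[0].data_dir

    # Phase 2: the real parser builds its defaults on top of data_path.
    parser = argparse.ArgumentParser()
    parser.add_argument("--styles-file", type=str, default=os.path.join(data_path, "styles.csv"))
    cmd_opts = parser.parse_args()

Note that add_help=False in phase 1 is not part of this commit; a later patch in this series adds it so the pre-parser does not intercept --help before the real parser can answer it.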
diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 4ce238b8..9b00f76e 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -3,13 +3,12 @@ import safetensors.torch import os import collections from collections import namedtuple -from modules import shared, devices, script_callbacks, sd_models -from modules.paths import models_path +from modules import paths, shared, devices, script_callbacks, sd_models import glob from copy import deepcopy -vae_path = os.path.abspath(os.path.join(models_path, "VAE")) +vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE")) vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} vae_dict = {} diff --git a/modules/shared.py b/modules/shared.py index 14be993d..474fcc42 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.memmon import modules.styles import modules.devices as devices from modules import localization, extensions, script_loading, errors, ui_components, shared_items -from modules.paths import models_path, script_path +from modules.paths import models_path, script_path, data_path demo = None @@ -25,6 +25,7 @@ sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file parser = argparse.ArgumentParser() +parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",) parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") @@ -35,7 +36,7 @@ parser.add_argument("--no-half", action='store_true', help="do not switch the mo parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") -parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--embeddings-dir", type=str, default=os.path.join(data_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") @@ -74,16 +75,16 @@ parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for sp parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults 
to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) -parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json')) +parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(data_path, 'ui-config.json')) parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False) parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False) -parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) +parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(data_path, 'config.json')) parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything') parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") -parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) +parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index c0ac11d3..2239cb84 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -6,8 +6,7 @@ import sys import tqdm import time -from modules import shared, images, deepbooru -from modules.paths import models_path +from modules import paths, shared, images, deepbooru from modules.shared import opts, cmd_opts from modules.textual_inversion import autocrop @@ -199,7 +198,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre dnn_model_path = None try: - dnn_model_path = autocrop.download_and_cache_models(os.path.join(models_path, "opencv")) + dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv")) except Exception as e: print("Unable to load face detection model for auto crop selection. 
Falling back to lower quality haar method.", e) diff --git a/modules/ui.py b/modules/ui.py index 85ae62c7..0117df3e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -21,7 +21,7 @@ from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_grad from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components, ui_common, ui_postprocessing from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML -from modules.paths import script_path +from modules.paths import script_path, data_path from modules.shared import opts, cmd_opts, restricted_opts @@ -1497,8 +1497,8 @@ def create_ui(): with open(cssfile, "r", encoding="utf8") as file: css += file.read() + "\n" - if os.path.exists(os.path.join(script_path, "user.css")): - with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: + if os.path.exists(os.path.join(data_path, "user.css")): + with open(os.path.join(data_path, "user.css"), "r", encoding="utf8") as file: css += file.read() + "\n" if not cmd_opts.no_progressbar_hiding: diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 742e745e..66a41865 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -132,7 +132,7 @@ def install_extension_from_url(dirname, url): normalized_url = normalize_git_url(url) assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed' - tmpdir = os.path.join(paths.script_path, "tmp", dirname) + tmpdir = os.path.join(paths.data_path, "tmp", dirname) try: shutil.rmtree(tmpdir, True) diff --git a/modules/upscaler.py b/modules/upscaler.py index a5bf5acb..e2eaa730 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -11,7 +11,6 @@ from modules import modelloader, shared LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST) -from modules.paths import models_path class Upscaler: @@ -39,7 +38,7 @@ class Upscaler: self.mod_scale = None if self.model_path is None and self.name: - self.model_path = os.path.join(models_path, self.name) + self.model_path = os.path.join(shared.models_path, self.name) if self.model_path and create_dirs: os.makedirs(self.model_path, exist_ok=True) @@ -143,4 +142,4 @@ class UpscalerNearest(Upscaler): def __init__(self, dirname=None): super().__init__(False) self.name = "Nearest" - self.scalers = [UpscalerData("Nearest", None, self)] \ No newline at end of file + self.scalers = [UpscalerData("Nearest", None, self)] -- cgit v1.2.3 From 14c0884fd0948c478db165989cca7aaffc9a0504 Mon Sep 17 00:00:00 2001 From: Max Audron Date: Wed, 25 Jan 2023 17:55:59 +0100 Subject: use python importlib to load and execute extension modules previously module attributes like __file__ were not set correctly, leading to scripts getting the directory of the stable-diffusion repo location instead of their own script. This causes problems when loading user data from an external location using the --data-dir flag, as extensions would look for their own code in the stable-diffusion repo location instead of the data dir location. Using Python's importlib functions sets the module specs correctly and executes them. But this will break extensions if they build paths based on the previously incorrect __file__ attribute.
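The behaviour change described above is easiest to see side by side. A sketch contrasting the old exec()-based loader with the importlib-based one introduced in the diff below (the helper names here are illustrative, not part of the patch):

    import importlib.util
    import os
    from types import ModuleType

    def load_module_exec(path):
        # old approach: exec() the source into a bare ModuleType;
        # __file__ and __spec__ are never set on the resulting module
        with open(path, "r", encoding="utf8") as file:
            text = file.read()
        module = ModuleType(os.path.basename(path))
        exec(compile(text, path, 'exec'), module.__dict__)
        return module

    def load_module_importlib(path):
        # new approach: build a spec from the file location, which fills in
        # __file__ and __spec__, then execute the module against that spec
        spec = importlib.util.spec_from_file_location(os.path.basename(path), path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

Only under the second loader can an extension script rely on its own __file__ to locate sibling files, which is what lets extensions work from a --data-dir outside the repository.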
--- modules/script_loading.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/modules/script_loading.py b/modules/script_loading.py index f93f0951..a7d2203f 100644 --- a/modules/script_loading.py +++ b/modules/script_loading.py @@ -1,16 +1,14 @@ import os import sys import traceback +import importlib.util from types import ModuleType def load_module(path): - with open(path, "r", encoding="utf8") as file: - text = file.read() - - compiled = compile(text, path, 'exec') - module = ModuleType(os.path.basename(path)) - exec(compiled, module.__dict__) + module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path) + module = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(module) return module -- cgit v1.2.3 From 6b3981c0685cd1df750df4eb51823f1cfd70c6d5 Mon Sep 17 00:00:00 2001 From: Max Audron Date: Wed, 25 Jan 2023 18:00:09 +0100 Subject: clean up unused script_path imports --- modules/codeformer_model.py | 2 +- modules/generation_parameters_copypaste.py | 2 +- webui.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index ab40d842..01fb7bd8 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -8,7 +8,7 @@ import torch import modules.face_restoration import modules.shared from modules import shared, devices, modelloader -from modules.paths import script_path, models_path +from modules.paths import models_path # codeformer people made a choice to include modified basicsr library to their project which makes # it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN. diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 35f72808..773c5c0e 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -6,7 +6,7 @@ import re from pathlib import Path import gradio as gr -from modules.paths import data_path, script_path +from modules.paths import data_path from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image diff --git a/webui.py b/webui.py index e1565a8d..41f32f5c 100644 --- a/webui.py +++ b/webui.py @@ -15,7 +15,6 @@ logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not from modules import import_hook, errors, extra_networks from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call -from modules.paths import script_path import torch -- cgit v1.2.3 From 23a9d5e27390846dea0895a02c04aec9583a4d38 Mon Sep 17 00:00:00 2001 From: Max Audron Date: Wed, 25 Jan 2023 18:18:55 +0100 Subject: create user extensions directory if not exists --- modules/extensions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modules/extensions.py b/modules/extensions.py index 92ee8144..5e12b1aa 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -10,6 +10,8 @@ extensions = [] extensions_dir = os.path.join(paths.data_path, "extensions") extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin") +if not os.path.exists(extensions_dir): + os.makedirs(extensions_dir) def active(): return [x for x in extensions if x.enabled] -- cgit v1.2.3 From eafaf14167cf574ad0f918c10f60ef86aea9cd20 Mon Sep 17 00:00:00 2001 From: Gazzoo-byte <73721238+Gazzoo-byte@users.noreply.github.com> 
Date: Fri, 27 Jan 2023 18:34:41 +0000 Subject: Add button to switch width and height Adds a button to switch width and height, allowing quick and easy switching between landscape and portrait. --- modules/ui.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/modules/ui.py b/modules/ui.py index 85ae62c7..fb0e4d5c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -91,6 +91,13 @@ save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 clear_prompt_symbol = '\U0001F5D1' # 🗑️ extra_networks_symbol = '\U0001F3B4' # 🎴 +switch_values_symbol = '\U000021C5' # ⇅ + +def switch_width_and_height(width, height): + width_temp = width + width = height + height = width_temp + return width, height def plaintext_to_html(text): @@ -466,6 +473,7 @@ def create_ui(): height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="txt2img_height") if opts.dimensions_and_batch_together: + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="txt2img_res_switch_btn") with gr.Column(elem_id="txt2img_column_batch"): batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") @@ -566,6 +574,7 @@ def create_ui(): txt2img_prompt.submit(**txt2img_args) submit.click(**txt2img_args) + res_switch_btn.click(switch_width_and_height, inputs=[width, height], outputs=[width, height]) txt_prompt_img.change( fn=modules.images.image_data, @@ -728,6 +737,7 @@ def create_ui(): height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="img2img_height") if opts.dimensions_and_batch_together: + res_switch_btn = ToolButton(value=switch_values_symbol, elem_id="img2img_res_switch_btn") with gr.Column(elem_id="img2img_column_batch"): batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") @@ -865,6 +875,7 @@ def create_ui(): img2img_prompt.submit(**img2img_args) submit.click(**img2img_args) + res_switch_btn.click(switch_width_and_height, inputs=[width, height], outputs=[width, height]) img2img_interrogate.click( fn=lambda *args: process_interrogate(interrogate, *args), -- cgit v1.2.3 From a6a5bfb15531b19ce0319593d67d05a356f49a65 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Fri, 27 Jan 2023 13:48:39 -0500 Subject: deepcopy pc.styles, allows for multiple style axis --- scripts/xyz_grid.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 828c2d12..f2fe506c 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -123,7 +123,7 @@ def apply_vae(p, x, xs): def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _): - p.styles = x.split(',') + p.styles.extend(x.split(',')) def format_value_add_label(p, opt, x): @@ -533,6 +533,7 @@ class Script(scripts.Script): return Processed(p, [], p.seed, "") pc = copy(p) + pc.styles = pc.styles[:] x_opt.apply(pc, x, xs) y_opt.apply(pc, y, ys) z_opt.apply(pc, z, zs) -- cgit v1.2.3 From 32d389ef0f7c75dd85fc7aebe7bca279f36fed86 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Fri, 27 Jan 2023 14:04:23 -0500 Subject: changes remaining text from X/Y -> X/Y/Z --- README.md | 2 +- javascript/hints.js | 2 +- scripts/xyz_grid.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 
c6bd6f27..2149dcc5 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ A browser interface based on Gradio library for Stable Diffusion. - a man in a (tuxedo:1.21) - alternative syntax - select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user) - Loopback, run img2img processing multiple times -- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters +- X/Y/Z plot, a way to draw a 3 dimensional plot of images with different parameters - Textual Inversion - have as many embeddings as you want and use any names you like for them - use multiple embeddings with different numbers of vectors per token diff --git a/javascript/hints.js b/javascript/hints.js index 3cf10e20..7b60b25e 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -50,7 +50,7 @@ titles = { "None": "Do not do anything special", "Prompt matrix": "Separate prompts into parts using vertical pipe character (|) and the script will create a picture for every combination of them (except for the first part, which will be present in all combinations)", - "X/Y plot": "Create a grid where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows", + "X/Y/Z plot": "Create grid(s) where images will have different parameters. Use inputs below to specify which parameters will be shared by columns and rows", "Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work", "Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others", diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index f2fe506c..f0116055 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -499,7 +499,7 @@ class Script(scripts.Script): image_cell_count = p.n_iter * p.batch_size cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else "" plural_s = 's' if len(zs) > 1 else '' - print(f"X/Y plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. (Total steps to process: {total_steps})") + print(f"X/Y/Z plot will create {len(xs) * len(ys) * len(zs) * image_cell_count} images on {len(zs)} {len(xs)}x{len(ys)} grid{plural_s}{cell_console_text}. 
(Total steps to process: {total_steps})") shared.total_tqdm.updateTotal(total_steps) grid_infotext = [None] -- cgit v1.2.3 From cc8c9b7474d917888a0bd069fcd59a458c67ae4b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 27 Jan 2023 22:43:08 +0300 Subject: fix broken calls to find_checkpoint_config --- modules/extras.py | 4 ++-- modules/sd_hijack_ip2p.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/extras.py b/modules/extras.py index 36123aa5..4f842be9 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -6,7 +6,7 @@ import shutil import torch import tqdm -from modules import shared, images, sd_models, sd_vae +from modules import shared, images, sd_models, sd_vae, sd_models_config from modules.ui_common import plaintext_to_html import gradio as gr import safetensors.torch @@ -37,7 +37,7 @@ def run_pnginfo(image): def create_config(ckpt_result, config_source, a, b, c): def config(x): - res = sd_models.find_checkpoint_config(x) if x else None + res = sd_models_config.find_checkpoint_config_near_filename(x) if x else None return res if res != shared.sd_default_config else None if config_source == 0: diff --git a/modules/sd_hijack_ip2p.py b/modules/sd_hijack_ip2p.py index 635f015f..3c727d3b 100644 --- a/modules/sd_hijack_ip2p.py +++ b/modules/sd_hijack_ip2p.py @@ -5,9 +5,9 @@ import gc import time def should_hijack_ip2p(checkpoint_info): - from modules import sd_models + from modules import sd_models_config ckpt_basename = os.path.basename(checkpoint_info.filename).lower() - cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower() + cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower() return "pix2pix" in ckpt_basename and not "pix2pix" in cfg_basename -- cgit v1.2.3 From 6b82efd737827bbeef202f04ff5a8faec9b64ef8 Mon Sep 17 00:00:00 2001 From: MrCheeze Date: Fri, 27 Jan 2023 20:06:19 -0500 Subject: add v2-inpainting model detection, and broaden v-model detection to include anything with 768 in the name --- modules/sd_models_config.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 4d1e92e1..73854a45 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -10,6 +10,7 @@ sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", config_default = shared.sd_default_config config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") +config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_inpainting = os.path.join(sd_configs_path, "v1-inpainting-inference.yaml") config_instruct_pix2pix = os.path.join(sd_configs_path, "instruct-pix2pix.yaml") @@ -28,7 +29,9 @@ def guess_model_config_from_state_dict(sd, filename): return config_depth_model if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: - if re.search(re_parametrization_v, fn) or "v2-1_768" in fn: + if diffusion_model_input.shape[1] == 9: + return config_sd2_inpainting + elif re.search(re_parametrization_v, fn) or "768" in fn: return config_sd2v else: return config_sd2 -- cgit v1.2.3 From 2aac1d97782b486f3a4a5209cf399dcdcb7bbb4d Mon Sep 17 00:00:00 2001 From: Andrii Skaliuk Date: Fri, 27 Jan 2023 17:32:31 -0800 Subject: Basic inpainting batch 
support Modifies batch UI to add optional inpainting support --- modules/img2img.py | 20 +++++++++++++++++--- modules/ui.py | 9 ++++++++- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/modules/img2img.py b/modules/img2img.py index 2168c8e2..fe9447c7 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -16,11 +16,16 @@ import modules.images as images import modules.scripts -def process_batch(p, input_dir, output_dir, args): +def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): processing.fix_seed(p) images = shared.listfiles(input_dir) + inpaint_masks = shared.listfiles(inpaint_mask_dir) + is_inpaint_batch = inpaint_mask_dir and len(inpaint_masks) > 0 + if is_inpaint_batch: + print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.") + print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") save_normally = output_dir == '' @@ -43,6 +48,15 @@ def process_batch(p, input_dir, output_dir, args): img = ImageOps.exif_transpose(img) p.init_images = [img] * p.batch_size + if is_inpaint_batch: + # try to find corresponding mask for an image using simple filename matching + mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image)) + # if not found use first one ("same mask for all images" use-case) + if not mask_image_path in inpaint_masks: + mask_image_path = inpaint_masks[0] + mask_image = Image.open(mask_image_path) + p.image_mask = mask_image + proc = modules.scripts.scripts_img2img.run(p, *args) if proc is None: proc = process_images(p) @@ -59,7 +73,7 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): +def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, *args): is_batch = mode == 5 if mode == 0: # img2img @@ -139,7 +153,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if is_batch: assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" - process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, args) + 
process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args) processed = Processed(p, [], p.seed, "") else: diff --git a/modules/ui.py b/modules/ui.py index 85ae62c7..fddb9177 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -691,9 +691,15 @@ def create_ui(): with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' - gr.HTML(f"<p style=\"margin-bottom:0.75em\">Process images in a directory on the same machine where the server is running.<br>Use an empty output directory to save pictures normally instead of writing to the output directory.{hidden}</p>") + gr.HTML( + f"<p style=\"margin-bottom:0.75em\">Process images in a directory on the same machine where the server is running." + + f"<br>Use an empty output directory to save pictures normally instead of writing to the output directory." + + f"<br>Add inpaint batch mask directory to enable inpaint batch processing." + f"{hidden}</p>
      " + ) img2img_batch_input_dir = gr.Textbox(label="Input directory", **shared.hide_dirs, elem_id="img2img_batch_input_dir") img2img_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, elem_id="img2img_batch_output_dir") + img2img_batch_inpaint_mask_dir = gr.Textbox(label="Inpaint batch mask directory (required for inpaint batch processing only)", **shared.hide_dirs, elem_id="img2img_batch_inpaint_mask_dir") def copy_image(img): if isinstance(img, dict) and 'image' in img: @@ -838,6 +844,7 @@ def create_ui(): inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, + img2img_batch_inpaint_mask_dir ] + custom_inputs, outputs=[ img2img_gallery, -- cgit v1.2.3 From 4c52dfe4ac98c53431ecd267d59f27391d3a63e7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 08:30:17 +0300 Subject: make the detection for -v models less broad --- modules/sd_models_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 73854a45..00217990 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -31,7 +31,7 @@ def guess_model_config_from_state_dict(sd, filename): if sd2_cond_proj_weight is not None and sd2_cond_proj_weight.shape[1] == 1024: if diffusion_model_input.shape[1] == 9: return config_sd2_inpainting - elif re.search(re_parametrization_v, fn) or "768" in fn: + elif re.search(re_parametrization_v, fn): return config_sd2v else: return config_sd2 -- cgit v1.2.3 From 0834d4ce374225131e025540220c727e352a3e43 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 08:41:15 +0300 Subject: simplify #7284 --- modules/ui.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/modules/ui.py b/modules/ui.py index 3c0a4050..ca2c1eb6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -93,12 +93,6 @@ clear_prompt_symbol = '\U0001F5D1' # 🗑️ extra_networks_symbol = '\U0001F3B4' # 🎴 switch_values_symbol = '\U000021C5' # ⇅ -def switch_width_and_height(width, height): - width_temp = width - width = height - height = width_temp - return width, height - def plaintext_to_html(text): return ui_common.plaintext_to_html(text) @@ -574,7 +568,8 @@ def create_ui(): txt2img_prompt.submit(**txt2img_args) submit.click(**txt2img_args) - res_switch_btn.click(switch_width_and_height, inputs=[width, height], outputs=[width, height]) + + res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height]) txt_prompt_img.change( fn=modules.images.image_data, @@ -882,7 +877,7 @@ def create_ui(): img2img_prompt.submit(**img2img_args) submit.click(**img2img_args) - res_switch_btn.click(switch_width_and_height, inputs=[width, height], outputs=[width, height]) + res_switch_btn.click(lambda w, h: (h, w), inputs=[width, height], outputs=[width, height]) img2img_interrogate.click( fn=lambda *args: process_interrogate(interrogate, *args), -- cgit v1.2.3 From 7d1f2a3a495327341ef1b3238347864845799bb6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 10:21:31 +0300 Subject: remove waiting for input on version mismatch warning, change supported versions --- .gitignore | 1 - launch.py | 35 ++++++++++++----------------------- 2 files changed, 12 insertions(+), 24 deletions(-) diff --git a/.gitignore b/.gitignore index c8be9688..0b1d17ca 100644 --- a/.gitignore +++ b/.gitignore @@ -33,4 +33,3 @@ notification.mp3 /test/stdout.txt /test/stderr.txt /cache.json 
-no_py_ver_warning diff --git a/launch.py b/launch.py index 4f5a4bc4..7614f9c9 100644 --- a/launch.py +++ b/launch.py @@ -18,35 +18,24 @@ skip_install = False def check_python_version(): - if not os.path.isfile("no_py_ver_warning"): - version = sys.version_info - version_range = None - if platform.system() == "Linux": - version_range = range(7 + 1, 11 + 1) - else: - version_range = range(7 + 1, 10 + 1) + version = sys.version_info + if platform.system() == "Windows": + supported_minors = [10] + else: + supported_minors = [7, 8, 9, 10, 11] + + if not (version.major == 3 and version.minor in supported_minors): + import modules.errors - try: - assert version.major == 3 and version.minor in version_range, f""" -=== Warning === -This program was tested only with 3.10 Python, but you have {version.major}.{version.minor} Python. + modules.errors.print_error_explanation(f""" +This program is tested with 3.10.6 Python, but you have {version.major}.{version.minor}.{version.micro}. If you encounter an error with "RuntimeError: Couldn't install torch." message, or any other error regarding unsuccessful package (library) installation, please downgrade (or upgrade) to the latest version of 3.10 Python and delete current Python and "venv" folder in WebUI's directory. -You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/ - -You will see this warning only once, delete file "no_py_ver_warning" file to show this warning again. -=== Warning === - -Press ENTER to continue...\ -""" - except AssertionError as e: - print(e) - with open("no_py_ver_warning", "w"): - pass - input() +You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/\ +""") def commit_hash(): -- cgit v1.2.3 From 3752aad23d4be4522f9edf3fe79c1122fa5ad509 Mon Sep 17 00:00:00 2001 From: Mackerel Date: Sat, 28 Jan 2023 02:44:12 -0500 Subject: don't replace regular --help with new paths.py parser help --- modules/paths.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/paths.py b/modules/paths.py index 08e6f9b9..d991cc71 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -6,7 +6,7 @@ import modules.safe script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) # Parse the --data-dir flag first so we can use it as a base for our other argument default values -parser = argparse.ArgumentParser() +parser = argparse.ArgumentParser(add_help=False) parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",) cmd_opts_pre = parser.parse_known_args()[0] data_path = cmd_opts_pre.data_dir -- cgit v1.2.3 From bd52a6d89970cca4f0f8b4275db895c99e173b3f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 10:48:08 +0300 Subject: some more changes for python version warning; add a commandline flag to disable --- launch.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/launch.py b/launch.py index 7614f9c9..370920de 100644 --- a/launch.py +++ b/launch.py @@ -18,23 +18,33 @@ skip_install = False def check_python_version(): - version = sys.version_info - if platform.system() == "Windows": + is_windows = platform.system() == "Windows" + major = sys.version_info.major + minor = sys.version_info.minor + micro = sys.version_info.micro + + if is_windows: supported_minors = [10] else: supported_minors = [7, 8, 9, 10, 11] - if not (version.major == 3 and version.minor in 
supported_minors): + if not (major == 3 and minor in supported_minors): import modules.errors modules.errors.print_error_explanation(f""" -This program is tested with 3.10.6 Python, but you have {version.major}.{version.minor}.{version.micro}. +INCOMPATIBLE PYTHON VERSION + +This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}. If you encounter an error with "RuntimeError: Couldn't install torch." message, or any other error regarding unsuccessful package (library) installation, please downgrade (or upgrade) to the latest version of 3.10 Python and delete current Python and "venv" folder in WebUI's directory. -You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/\ +You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/ + +{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""} + +Use --skip-python-version-check to suppress this warning. """) @@ -237,6 +247,7 @@ def prepare_environment(): sys.argv, _ = extract_arg(sys.argv, '-f') sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test') + sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check') sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers') sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch') sys.argv, update_check = extract_arg(sys.argv, '--update-check') @@ -245,6 +256,9 @@ def prepare_environment(): xformers = '--xformers' in sys.argv ngrok = '--ngrok' in sys.argv + if not skip_python_version_check: + check_python_version() + commit = commit_hash() print(f"Python {sys.version}") @@ -342,6 +356,5 @@ def start(): if __name__ == "__main__": - check_python_version() prepare_environment() start() -- cgit v1.2.3
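For reference, the new flag is consumed through launch.py's pre-existing extract_arg helper, whose body does not appear in these hunks; its contract (remove a flag from argv and report whether it was present) can be approximated as follows, a sketch under that assumption rather than the repository's exact code:

    import sys

    def extract_arg(args, name):
        # return args with every occurrence of `name` removed,
        # plus a flag saying whether it appeared at all
        return [x for x in args if x != name], name in args

    sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
    print(skip_python_version_check)  # True if the flag was passed

This strips the flag before argparse or any later startup code sees it, while the returned boolean lets the very early bootstrap (here, the Python version check) react to it.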