From be27fd4690b1eb6c74da1e31c9696a0f1901fbba Mon Sep 17 00:00:00 2001 From: evshiron Date: Sun, 30 Oct 2022 17:01:01 +0800 Subject: fix broken progress api by previous rework --- modules/shared.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index e4f163c1..2c7d28a5 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -4,6 +4,7 @@ import json import os import sys from collections import OrderedDict +import time import gradio as gr import tqdm @@ -132,6 +133,7 @@ class State: current_image = None current_image_sampling_step = 0 textinfo = None + time_start = None def skip(self): self.skipped = True @@ -168,6 +170,7 @@ class State: self.skipped = False self.interrupted = False self.textinfo = None + self.time_start = time.time() devices.torch_gc() -- cgit v1.2.3 From cb31abcf58ea1f64266e6d821937eed058c35f4d Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 21:54:31 +0700 Subject: Settings to select VAE --- modules/sd_models.py | 31 +++++-------- modules/sd_vae.py | 121 +++++++++++++++++++++++++++++++++++++++++++++++++++ modules/shared.py | 8 ++-- webui.py | 5 +++ 4 files changed, 141 insertions(+), 24 deletions(-) create mode 100644 modules/sd_vae.py (limited to 'modules/shared.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index f86dc3ed..91ad4b5e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -8,7 +8,7 @@ from omegaconf import OmegaConf from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks +from modules import shared, modelloader, devices, script_callbacks, sd_vae from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting @@ -160,12 +160,11 @@ def get_state_dict_from_checkpoint(pl_sd): vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} - -def load_model_weights(model, checkpoint_info): +def load_model_weights(model, checkpoint_info, force=False): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash - if checkpoint_info not in checkpoints_loaded: + if force or checkpoint_info not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) @@ -186,17 +185,7 @@ def load_model_weights(model, checkpoint_info): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" - - if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None: - vae_file = shared.cmd_opts.vae_path - - if os.path.exists(vae_file): - print(f"Loading VAE weights from: {vae_file}") - vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} - model.first_stage_model.load_state_dict(vae_dict) - + sd_vae.load_vae(model, checkpoint_file) model.first_stage_model.to(devices.dtype_vae) if shared.opts.sd_checkpoint_cache > 0: @@ -213,7 +202,7 @@ def load_model_weights(model, checkpoint_info): model.sd_checkpoint_info = checkpoint_info -def load_model(checkpoint_info=None): +def load_model(checkpoint_info=None, force=False): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -234,7 +223,7 @@ def load_model(checkpoint_info=None): do_inpainting_hijack() sd_model = instantiate_from_config(sd_config.model) - load_model_weights(sd_model, checkpoint_info) + load_model_weights(sd_model, checkpoint_info, force=force) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) @@ -252,16 +241,16 @@ def load_model(checkpoint_info=None): return sd_model -def reload_model_weights(sd_model, info=None): +def reload_model_weights(sd_model, info=None, force=False): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() - if sd_model.sd_model_checkpoint == checkpoint_info.filename: + if sd_model.sd_model_checkpoint == checkpoint_info.filename and not force: return if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - load_model(checkpoint_info) + load_model(checkpoint_info, force=force) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: @@ -271,7 +260,7 @@ def reload_model_weights(sd_model, info=None): sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, checkpoint_info) + load_model_weights(sd_model, checkpoint_info, force=force) sd_hijack.model_hijack.hijack(sd_model) script_callbacks.model_loaded_callback(sd_model) diff --git a/modules/sd_vae.py b/modules/sd_vae.py new file mode 100644 index 00000000..82764e55 --- /dev/null +++ b/modules/sd_vae.py @@ -0,0 +1,121 @@ +import torch +import os +from collections import namedtuple +from modules import shared, devices +from modules.paths import models_path +import glob + +model_dir = "Stable-diffusion" +model_path = os.path.abspath(os.path.join(models_path, model_dir)) +vae_dir = "VAE" +vae_path = os.path.abspath(os.path.join(models_path, vae_dir)) + +vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} +default_vae_dict = {"auto": "auto", "None": "None"} +default_vae_list = ["auto", "None"] +default_vae_values = [default_vae_dict[x] for x in default_vae_list] +vae_dict = dict(default_vae_dict) +vae_list = list(default_vae_list) +first_load = True + +def get_filename(filepath): + return os.path.splitext(os.path.basename(filepath))[0] + +def refresh_vae_list(vae_path=vae_path, model_path=model_path): + global vae_dict, vae_list + res = {} + candidates = [ + *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), + *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), + *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), + *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True) + ] + if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): + candidates.append(shared.cmd_opts.vae_path) + for filepath in candidates: + name = get_filename(filepath) + res[name] = filepath + vae_list.clear() + vae_list.extend(default_vae_list) + vae_list.extend(list(res.keys())) + vae_dict.clear() + vae_dict.update(default_vae_dict) + vae_dict.update(res) + return vae_list + +def load_vae(model, checkpoint_file, vae_file="auto"): + global first_load, vae_dict, vae_list + # save_settings = False + + # if vae_file argument is provided, it takes priority + if vae_file and vae_file not in default_vae_list: + if not os.path.isfile(vae_file): + vae_file = "auto" + # save_settings = True + print("VAE provided as function argument doesn't exist") + # for the first load, if vae-path is provided, it takes priority and failure is reported + if first_load and shared.cmd_opts.vae_path is not None: + if os.path.isfile(shared.cmd_opts.vae_path): + vae_file = shared.cmd_opts.vae_path + # save_settings = True + # print("Using VAE provided as command line argument") + else: + print("VAE provided as command line argument doesn't exist") + # else, we load from settings + if vae_file == "auto" and shared.opts.sd_vae is not None: + # if saved VAE settings isn't recognized, fallback to auto + vae_file = vae_dict.get(shared.opts.sd_vae, "auto") + # if VAE selected but not found, fallback to auto + if vae_file not in default_vae_values and not os.path.isfile(vae_file): + vae_file = "auto" + print("Selected VAE doesn't exist") + # vae-path cmd arg takes priority for auto + if vae_file == "auto" and shared.cmd_opts.vae_path is not None: + if os.path.isfile(shared.cmd_opts.vae_path): + vae_file = shared.cmd_opts.vae_path + print("Using VAE provided as command line argument") + # if still not found, try look for ".vae.pt" beside model + model_path = os.path.splitext(checkpoint_file)[0] + if vae_file == "auto": + vae_file_try = model_path + ".vae.pt" + if os.path.isfile(vae_file_try): + vae_file = vae_file_try + print("Using VAE found beside selected model") + # if still not found, try look for ".vae.ckpt" beside model + if vae_file == "auto": + vae_file_try = model_path + ".vae.ckpt" + if os.path.isfile(vae_file_try): + vae_file = vae_file_try + print("Using VAE found beside selected model") + # No more fallbacks for auto + if vae_file == "auto": + vae_file = None + # Last check, just because + if vae_file and not os.path.exists(vae_file): + vae_file = None + + if vae_file: + print(f"Loading VAE weights from: {vae_file}") + vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) + vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} + model.first_stage_model.load_state_dict(vae_dict_1) + + # If vae used is not in dict, update it + # It will be removed on refresh though + if vae_file is not None: + vae_opt = get_filename(vae_file) + if vae_opt not in vae_dict: + vae_dict[vae_opt] = vae_file + vae_list.append(vae_opt) + + """ + # Save current VAE to VAE settings, maybe? will it work? + if save_settings: + if vae_file is None: + vae_opt = "None" + + # shared.opts.sd_vae = vae_opt + """ + + first_load = False + model.first_stage_model.to(devices.dtype_vae) diff --git a/modules/shared.py b/modules/shared.py index e4f163c1..06440ac4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, sd_models, localization +from modules import sd_samplers, sd_models, localization, sd_vae from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path @@ -295,6 +295,7 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), + "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), @@ -407,11 +408,12 @@ class Options: if bad_settings > 0: print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr) - def onchange(self, key, func): + def onchange(self, key, func, call=True): item = self.data_labels.get(key) item.onchange = func - func() + if call: + func() def dumpjson(self): d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()} diff --git a/webui.py b/webui.py index 29530872..27949f3d 100644 --- a/webui.py +++ b/webui.py @@ -21,6 +21,7 @@ import modules.paths import modules.scripts import modules.sd_hijack import modules.sd_models +import modules.sd_vae import modules.shared as shared import modules.txt2img @@ -74,8 +75,12 @@ def initialize(): modules.scripts.load_scripts() + modules.sd_vae.refresh_vae_list() modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) + # I don't know what needs to be done to only reload VAE, with all those hijacks callbacks, and lowvram, + # so for now this reloads the whole model too, and no cache + shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model, force=True)), call=False) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) -- cgit v1.2.3 From d587586d3be2de061238defb8a556f03743287f6 Mon Sep 17 00:00:00 2001 From: mawr Date: Mon, 31 Oct 2022 00:14:07 +0300 Subject: Added "--clip-models-path" switch to avoid using default "~/.cache/clip" and enable to run under unprivileged user without homedir --- modules/interrogate.py | 4 ++-- modules/shared.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/interrogate.py b/modules/interrogate.py index 65b05d34..9769aa34 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -56,9 +56,9 @@ class InterrogateModels: import clip if self.running_on_cpu: - model, preprocess = clip.load(clip_model_name, device="cpu") + model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path) else: - model, preprocess = clip.load(clip_model_name) + model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path) model.eval() model = model.to(devices.device_interrogate) diff --git a/modules/shared.py b/modules/shared.py index e4f163c1..36212031 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -51,6 +51,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET')) parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR')) +parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None) parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator") -- cgit v1.2.3 From 006756f9cd6258eae418e9209cfc13f940ec53e1 Mon Sep 17 00:00:00 2001 From: Fampai <> Date: Mon, 31 Oct 2022 07:26:08 -0400 Subject: Added TI training optimizations option to use xattention optimizations when training option to unload vae when training --- modules/shared.py | 3 ++- modules/textual_inversion/textual_inversion.py | 9 +++++++++ modules/textual_inversion/ui.py | 7 +++++-- 3 files changed, 16 insertions(+), 3 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index fb84afd8..4c3d0ce7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -256,11 +256,12 @@ options_templates.update(options_section(('system', "System"), { })) options_templates.update(options_section(('training', "Training"), { - "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."), + "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"), + "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"), })) options_templates.update(options_section(('sd', "Stable Diffusion"), { diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 17dfb223..b0a1d26b 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -214,6 +214,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name) + unload = shared.opts.unload_models_when_training if save_embedding_every > 0: embedding_dir = os.path.join(log_directory, "embeddings") @@ -238,6 +239,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) + if unload: + shared.sd_model.first_stage_model.to(devices.cpu) hijack = sd_hijack.model_hijack @@ -303,6 +306,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc if images_dir is not None and steps_done % create_image_every == 0: forced_filename = f'{embedding_name}-{steps_done}' last_saved_image = os.path.join(images_dir, forced_filename) + + shared.sd_model.first_stage_model.to(devices.device) + p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, do_not_save_grid=True, @@ -330,6 +336,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc processed = processing.process_images(p) image = processed.images[0] + if unload: + shared.sd_model.first_stage_model.to(devices.cpu) + shared.state.current_image = image if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index e712284d..d679e6f4 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -25,8 +25,10 @@ def train_embedding(*args): assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' + apply_optimizations = shared.opts.training_xattention_optimizations try: - sd_hijack.undo_optimizations() + if not apply_optimizations: + sd_hijack.undo_optimizations() embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args) @@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)} except Exception: raise finally: - sd_hijack.apply_optimizations() + if not apply_optimizations: + sd_hijack.apply_optimizations() -- cgit v1.2.3 From 910a097ae2ed78a62101951f1b87137f9e1baaea Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 31 Oct 2022 17:36:45 +0300 Subject: add initial version of the extensions tab fix broken Restart Gradio button --- javascript/extensions.js | 24 +++++ modules/extensions.py | 83 +++++++++++++++ modules/generation_parameters_copypaste.py | 5 + modules/scripts.py | 21 +--- modules/shared.py | 10 +- modules/ui.py | 16 ++- modules/ui_extensions.py | 162 +++++++++++++++++++++++++++++ style.css | 22 +++- webui.py | 20 ++-- 9 files changed, 333 insertions(+), 30 deletions(-) create mode 100644 javascript/extensions.js create mode 100644 modules/extensions.py create mode 100644 modules/ui_extensions.py (limited to 'modules/shared.py') diff --git a/javascript/extensions.js b/javascript/extensions.js new file mode 100644 index 00000000..86f5336d --- /dev/null +++ b/javascript/extensions.js @@ -0,0 +1,24 @@ + +function extensions_apply(_, _){ + disable = [] + update = [] + gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ + if(x.name.startsWith("enable_") && ! x.checked) + disable.push(x.name.substr(7)) + + if(x.name.startsWith("update_") && x.checked) + update.push(x.name.substr(7)) + }) + + restart_reload() + + return [JSON.stringify(disable), JSON.stringify(update)] +} + +function extensions_check(){ + gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){ + x.innerHTML = "Loading..." + }) + + return [] +} \ No newline at end of file diff --git a/modules/extensions.py b/modules/extensions.py new file mode 100644 index 00000000..8d6ae848 --- /dev/null +++ b/modules/extensions.py @@ -0,0 +1,83 @@ +import os +import sys +import traceback + +import git + +from modules import paths, shared + + +extensions = [] +extensions_dir = os.path.join(paths.script_path, "extensions") + + +def active(): + return [x for x in extensions if x.enabled] + + +class Extension: + def __init__(self, name, path, enabled=True): + self.name = name + self.path = path + self.enabled = enabled + self.status = '' + self.can_update = False + + repo = None + try: + if os.path.exists(os.path.join(path, ".git")): + repo = git.Repo(path) + except Exception: + print(f"Error reading github repository info from {path}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + if repo is None or repo.bare: + self.remote = None + else: + self.remote = next(repo.remote().urls, None) + self.status = 'unknown' + + def list_files(self, subdir, extension): + from modules import scripts + + dirpath = os.path.join(self.path, subdir) + if not os.path.isdir(dirpath): + return [] + + res = [] + for filename in sorted(os.listdir(dirpath)): + res.append(scripts.ScriptFile(dirpath, filename, os.path.join(dirpath, filename))) + + res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] + + return res + + def check_updates(self): + repo = git.Repo(self.path) + for fetch in repo.remote().fetch("--dry-run"): + if fetch.flags != fetch.HEAD_UPTODATE: + self.can_update = True + self.status = "behind" + return + + self.can_update = False + self.status = "latest" + + def pull(self): + repo = git.Repo(self.path) + repo.remotes.origin.pull() + + +def list_extensions(): + extensions.clear() + + if not os.path.isdir(extensions_dir): + return + + for dirname in sorted(os.listdir(extensions_dir)): + path = os.path.join(extensions_dir, dirname) + if not os.path.isdir(path): + continue + + extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions) + extensions.append(extension) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index df70c728..985ec95e 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -17,6 +17,11 @@ paste_fields = {} bind_list = [] +def reset(): + paste_fields.clear() + bind_list.clear() + + def quote(text): if ',' not in str(text): return text diff --git a/modules/scripts.py b/modules/scripts.py index 96e44bfd..533db45c 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -7,7 +7,7 @@ import modules.ui as ui import gradio as gr from modules.processing import StableDiffusionProcessing -from modules import shared, paths, script_callbacks +from modules import shared, paths, script_callbacks, extensions AlwaysVisible = object() @@ -107,17 +107,8 @@ def list_scripts(scriptdirname, extension): for filename in sorted(os.listdir(basedir)): scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) - extdir = os.path.join(paths.script_path, "extensions") - if os.path.exists(extdir): - for dirname in sorted(os.listdir(extdir)): - dirpath = os.path.join(extdir, dirname) - scriptdirpath = os.path.join(dirpath, scriptdirname) - - if not os.path.isdir(scriptdirpath): - continue - - for filename in sorted(os.listdir(scriptdirpath)): - scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename))) + for ext in extensions.active(): + scripts_list += ext.list_files(scriptdirname, extension) scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] @@ -127,11 +118,7 @@ def list_scripts(scriptdirname, extension): def list_files_with_name(filename): res = [] - dirs = [paths.script_path] - - extdir = os.path.join(paths.script_path, "extensions") - if os.path.exists(extdir): - dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))] + dirs = [paths.script_path] + [ext.path for ext in extensions.active()] for dirpath in dirs: if not os.path.isdir(dirpath): diff --git a/modules/shared.py b/modules/shared.py index e4f163c1..cce87081 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -132,6 +132,7 @@ class State: current_image = None current_image_sampling_step = 0 textinfo = None + need_restart = False def skip(self): self.skipped = True @@ -354,6 +355,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) +options_templates.update(options_section((None, "Hidden options"), { + "disabled_extensions": OptionInfo([], "Disable those extensions"), +})) + +options_templates.update() + class Options: data = None @@ -365,8 +372,9 @@ class Options: def __setattr__(self, key, value): if self.data is not None: - if key in self.data: + if key in self.data or key in self.data_labels: self.data[key] = value + return return super(Options, self).__setattr__(key, value) diff --git a/modules/ui.py b/modules/ui.py index 5055ca64..2c15abb7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -19,7 +19,7 @@ import numpy as np from PIL import Image, PngImagePlugin -from modules import sd_hijack, sd_models, localization, script_callbacks +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions from modules.paths import script_path from modules.shared import opts, cmd_opts, restricted_opts @@ -671,6 +671,7 @@ def create_ui(wrap_gradio_gpu_call): import modules.img2img import modules.txt2img + parameters_copypaste.reset() with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) @@ -1511,8 +1512,9 @@ def create_ui(wrap_gradio_gpu_call): column = None with gr.Row(elem_id="settings").style(equal_height=False): for i, (k, item) in enumerate(opts.data_labels.items()): + section_must_be_skipped = item.section[0] is None - if previous_section != item.section: + if previous_section != item.section and not section_must_be_skipped: if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None): if column is not None: column.__exit__() @@ -1531,6 +1533,8 @@ def create_ui(wrap_gradio_gpu_call): if k in quicksettings_names and not shared.cmd_opts.freeze_settings: quicksettings_list.append((i, k, item)) components.append(dummy_component) + elif section_must_be_skipped: + components.append(dummy_component) else: component = create_setting_component(k) component_dict[k] = component @@ -1572,9 +1576,10 @@ def create_ui(wrap_gradio_gpu_call): def request_restart(): shared.state.interrupt() - settings_interface.gradio_ref.do_restart = True + shared.state.need_restart = True restart_gradio.click( + fn=request_restart, inputs=[], outputs=[], @@ -1612,14 +1617,15 @@ def create_ui(wrap_gradio_gpu_call): interfaces += script_callbacks.ui_tabs_callback() interfaces += [(settings_interface, "Settings", "settings")] + extensions_interface = ui_extensions.create_ui() + interfaces += [(extensions_interface, "Extensions", "extensions")] + with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Row(elem_id="quicksettings"): for i, k, item in quicksettings_list: component = create_setting_component(k, is_quicksettings=True) component_dict[k] = component - settings_interface.gradio_ref = demo - parameters_copypaste.integrate_settings_paste_fields(component_dict) parameters_copypaste.run_bind() diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py new file mode 100644 index 00000000..b7d747dc --- /dev/null +++ b/modules/ui_extensions.py @@ -0,0 +1,162 @@ +import json +import os.path +import shutil +import sys +import time +import traceback + +import git + +import gradio as gr +import html + +from modules import extensions, shared, paths + + +def apply_and_restart(disable_list, update_list): + disabled = json.loads(disable_list) + assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}" + + update = json.loads(update_list) + assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}" + + update = set(update) + + for ext in extensions.extensions: + if ext.name not in update: + continue + + try: + ext.pull() + except Exception: + print(f"Error pulling updates for {ext.name}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + shared.opts.disabled_extensions = disabled + shared.opts.save(shared.config_filename) + + shared.state.interrupt() + shared.state.need_restart = True + + +def check_updates(): + for ext in extensions.extensions: + if ext.remote is None: + continue + + try: + ext.check_updates() + except Exception: + print(f"Error checking updates for {ext.name}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + return extension_table() + + +def extension_table(): + code = f""" + + + + + + + + + + """ + + for ext in extensions.extensions: + if ext.can_update: + ext_status = f"""""" + else: + ext_status = ext.status + + code += f""" + + + + {ext_status} + + """ + + code += """ + +
ExtensionURLUpdate
{html.escape(ext.remote or '')}
+ """ + + return code + + +def install_extension_from_url(dirname, url): + assert url, 'No URL specified' + + if dirname is None or dirname == "": + *parts, last_part = url.split('/') + last_part = last_part.replace(".git", "") + + dirname = last_part + + target_dir = os.path.join(extensions.extensions_dir, dirname) + assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}' + + assert len([x for x in extensions.extensions if x.remote == url]) == 0, 'Extension with this URL is already installed' + + tmpdir = os.path.join(paths.script_path, "tmp", dirname) + + try: + shutil.rmtree(tmpdir, True) + + repo = git.Repo.clone_from(url, tmpdir) + repo.remote().fetch() + + os.rename(tmpdir, target_dir) + + extensions.list_extensions() + return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")] + finally: + shutil.rmtree(tmpdir, True) + + +def create_ui(): + import modules.ui + + with gr.Blocks(analytics_enabled=False) as ui: + with gr.Tabs(elem_id="tabs_extensions") as tabs: + with gr.TabItem("Installed"): + extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False) + extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False) + + with gr.Row(): + apply = gr.Button(value="Apply and restart UI", variant="primary") + check = gr.Button(value="Check for updates") + + extensions_table = gr.HTML(lambda: extension_table()) + + apply.click( + fn=apply_and_restart, + _js="extensions_apply", + inputs=[extensions_disabled_list, extensions_update_list], + outputs=[], + ) + + check.click( + fn=check_updates, + _js="extensions_check", + inputs=[], + outputs=[extensions_table], + ) + + with gr.TabItem("Install from URL"): + install_url = gr.Text(label="URL for extension's git repository") + install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto") + intall_button = gr.Button(value="Install", variant="primary") + intall_result = gr.HTML(elem_id="extension_install_result") + + intall_button.click( + fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]), + inputs=[install_dirname, install_url], + outputs=[extensions_table, intall_result], + ) + + return ui diff --git a/style.css b/style.css index 8b2211b1..859c3933 100644 --- a/style.css +++ b/style.css @@ -530,6 +530,26 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h min-height: 480px !important; } +/* Extensions */ + +#extensions{ + border-collapse: collapse; +} + +#extensions td, #extensions th{ + border: 1px solid #ccc; + padding: 0.25em 0.5em; +} + +#extensions input[type="checkbox"]{ + margin-right: 0.5em; +} + +#tab_extensions button{ + max-width: 16em; +} + + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running @@ -607,4 +627,4 @@ Then, you will need to add the RTL counterpart only if needed in the rtl section right: unset; left: 0.5em; } -} +} \ No newline at end of file diff --git a/webui.py b/webui.py index 29530872..ad2eb236 100644 --- a/webui.py +++ b/webui.py @@ -9,7 +9,7 @@ from fastapi.middleware.gzip import GZipMiddleware from modules.paths import script_path -from modules import devices, sd_samplers, upscaler +from modules import devices, sd_samplers, upscaler, extensions import modules.codeformer_model as codeformer import modules.extras import modules.face_restoration @@ -60,6 +60,11 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): def initialize(): + extensions.list_extensions() + #for ext in extensions.extensions: + # print(ext.name, ext.path, ext.enabled, ext.remote) + #exit() + if cmd_opts.ui_debug_mode: shared.sd_upscalers = upscaler.UpscalerLanczos().scalers modules.scripts.load_scripts() @@ -92,15 +97,18 @@ def create_api(app): api = Api(app, queue_lock) return api + def wait_on_server(demo=None): while 1: time.sleep(0.5) - if demo and getattr(demo, 'do_restart', False): + if shared.state.need_restart: + shared.state.need_restart = False time.sleep(0.5) demo.close() time.sleep(0.5) break + def api_only(): initialize() @@ -132,14 +140,16 @@ def webui(): app.add_middleware(GZipMiddleware, minimum_size=1000) - if (launch_api): + if launch_api: create_api(app) wait_on_server(demo) sd_samplers.set_samplers() - print('Reloading Custom Scripts') + print('Reloading extensions') + extensions.list_extensions() + print('Reloading custom scripts') modules.scripts.reload_scripts() print('Reloading modules: modules.ui') importlib.reload(modules.ui) @@ -148,8 +158,6 @@ def webui(): print('Restarting Gradio') - -task = [] if __name__ == "__main__": if cmd_opts.nowebui: api_only() -- cgit v1.2.3 From dc7425a56e7a014cbfa3b3d44ad2321e519fe378 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 31 Oct 2022 18:33:44 +0300 Subject: disable access to extension stuff for non-local servers --- modules/shared.py | 5 ++++- modules/ui_extensions.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index cce87081..a27c654e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -40,7 +40,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") -parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") +parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site") parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us") parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer')) @@ -97,6 +97,9 @@ restricted_opts = { "outdir_save", } +if cmd_opts.share or cmd_opts.listen: + cmd_opts.disable_extension_access = True + devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer']) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index b7d747dc..e74b7d68 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -13,7 +13,13 @@ import html from modules import extensions, shared, paths +def check_access(): + assert not shared.cmd_opts.disable_extension_access, "extension access disabed because of commandline flags" + + def apply_and_restart(disable_list, update_list): + check_access() + disabled = json.loads(disable_list) assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}" @@ -40,6 +46,8 @@ def apply_and_restart(disable_list, update_list): def check_updates(): + check_access() + for ext in extensions.extensions: if ext.remote is None: continue @@ -89,6 +97,8 @@ def extension_table(): def install_extension_from_url(dirname, url): + check_access() + assert url, 'No URL specified' if dirname is None or dirname == "": -- cgit v1.2.3 From 9e22a357545c0395c81dd800c72fa18f350545ec Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 31 Oct 2022 18:45:50 +0300 Subject: fix the error with extension tab not working because of the previous commit --- modules/shared.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index a27c654e..c83fb9f5 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -97,8 +97,7 @@ restricted_opts = { "outdir_save", } -if cmd_opts.share or cmd_opts.listen: - cmd_opts.disable_extension_access = True +cmd_opts.disable_extension_access = cmd_opts.share or cmd_opts.listen devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer']) -- cgit v1.2.3 From 4a8cf01f6f7f072cc9c67d6b31662384b212dd9c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 2 Nov 2022 12:12:32 +0300 Subject: remove duplicate code from #3970 --- modules/api/api.py | 10 +--------- modules/shared.py | 14 ++++++++++++++ modules/ui.py | 10 +--------- 3 files changed, 16 insertions(+), 18 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/api/api.py b/modules/api/api.py index b3d85e46..71c9c160 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -178,15 +178,7 @@ class Api: progress = min(progress, 1) - # copy from check_progress_call of ui.py - - if shared.parallel_processing_allowed: - if shared.state.sampling_step - shared.state.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.state.current_latent is not None: - if shared.opts.show_progress_grid: - shared.state.current_image = samples_to_image_grid(shared.state.current_latent) - else: - shared.state.current_image = sample_to_image(shared.state.current_latent) - shared.state.current_image_sampling_step = shared.state.sampling_step + shared.state.set_current_image() current_image = None if shared.state.current_image and not req.skip_current_image: diff --git a/modules/shared.py b/modules/shared.py index 04aaa648..e65f6080 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -184,6 +184,20 @@ class State: devices.torch_gc() + """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" + def set_current_image(self): + if not parallel_processing_allowed: + return + + if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and self.current_latent is not None: + if opts.show_progress_grid: + self.current_image = sd_samplers.samples_to_image_grid(self.current_latent) + else: + self.current_image = sd_samplers.sample_to_image(self.current_latent) + + self.current_image_sampling_step = self.sampling_step + + state = State() artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv')) diff --git a/modules/ui.py b/modules/ui.py index 45cd8c3f..784439ba 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -277,15 +277,7 @@ def check_progress_call(id_part): preview_visibility = gr_show(False) if opts.show_progress_every_n_steps > 0: - if shared.parallel_processing_allowed: - - if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None: - if opts.show_progress_grid: - shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent) - else: - shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent) - shared.state.current_image_sampling_step = shared.state.sampling_step - + shared.state.set_current_image() image = shared.state.current_image if image is None: -- cgit v1.2.3 From 9c67408004ed132637d10321bf44565f82055fd2 Mon Sep 17 00:00:00 2001 From: timntorres <116157310+timntorres@users.noreply.github.com> Date: Wed, 2 Nov 2022 02:18:21 -0700 Subject: Allow saving "before-highres-fix. (#4150) * Save image/s before doing highres fix. --- modules/processing.py | 17 +++++++++++++++-- modules/sd_samplers.py | 5 ++--- modules/shared.py | 1 + 3 files changed, 18 insertions(+), 5 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/processing.py b/modules/processing.py index b541ee2b..2dcf4879 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -521,7 +521,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: shared.state.job = f"Batch {n+1} out of {p.n_iter}" with devices.autocast(): - samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) + # Only Txt2Img needs an extra argument, n, when saving intermediate images pre highres fix. + if isinstance(p, StableDiffusionProcessingTxt2Img): + samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, n=n) + else: + samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) samples_ddim = samples_ddim.to(devices.dtype_vae) x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim) @@ -649,7 +653,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, n=0): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) if not self.enable_hr: @@ -685,6 +689,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) + # Save a copy of the image/s before doing highres fix, if applicable. + if opts.save and not self.do_not_save_samples and opts.save_images_before_highres_fix: + for i in range(self.batch_size): + # This batch's ith image. + img = sd_samplers.sample_to_image(samples, i) + # Index that accounts for both batch size and batch count. + ind = i + self.batch_size*n + images.save_image(img, self.outpath_samples, "", self.all_seeds[ind], self.all_prompts[ind], opts.samples_format, suffix=f"-before-highres-fix") + shared.state.nextjob() self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 44d4c189..d7fa89a0 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -93,9 +93,8 @@ def single_sample_to_image(sample): return Image.fromarray(x_sample) -def sample_to_image(samples): - return single_sample_to_image(samples[0]) - +def sample_to_image(samples, index=0): + return single_sample_to_image(samples[index]) def samples_to_image_grid(samples): return images.image_grid([single_sample_to_image(sample) for sample in samples]) diff --git a/modules/shared.py b/modules/shared.py index e65f6080..ce991424 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -255,6 +255,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."), + "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."), "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"), -- cgit v1.2.3 From eb5e82c7ddf5e72fa13b83bd1f12d3a07a4de1a4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 2 Nov 2022 12:45:03 +0300 Subject: do not unnecessarily run VAE one more time when saving intermediate image with hires fix --- modules/processing.py | 39 ++++++++++++++++++++------------------- modules/sd_samplers.py | 1 + modules/shared.py | 2 +- scripts/img2imgalt.py | 3 +-- 4 files changed, 23 insertions(+), 22 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/processing.py b/modules/processing.py index 2dcf4879..3a364b5f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -199,7 +199,7 @@ class StableDiffusionProcessing(): def init(self, all_prompts, all_seeds, all_subseeds): pass - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): raise NotImplementedError() def close(self): @@ -521,11 +521,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: shared.state.job = f"Batch {n+1} out of {p.n_iter}" with devices.autocast(): - # Only Txt2Img needs an extra argument, n, when saving intermediate images pre highres fix. - if isinstance(p, StableDiffusionProcessingTxt2Img): - samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, n=n) - else: - samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) + samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) samples_ddim = samples_ddim.to(devices.dtype_vae) x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim) @@ -653,7 +649,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, n=0): + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) if not self.enable_hr: @@ -666,9 +662,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] + """saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images""" + def save_intermediate(image, index): + if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix: + return + + if not isinstance(image, Image.Image): + image = sd_samplers.sample_to_image(image, index) + + images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix") + if opts.use_scale_latent_for_hires_fix: samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") + for i in range(samples.shape[0]): + save_intermediate(samples, i) else: decoded_samples = decode_first_stage(self.sd_model, samples) lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) @@ -678,6 +686,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) image = Image.fromarray(x_sample) + + save_intermediate(image, i) + image = images.resize_image(0, image, self.width, self.height) image = np.array(image).astype(np.float32) / 255.0 image = np.moveaxis(image, 2, 0) @@ -689,15 +700,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) - # Save a copy of the image/s before doing highres fix, if applicable. - if opts.save and not self.do_not_save_samples and opts.save_images_before_highres_fix: - for i in range(self.batch_size): - # This batch's ith image. - img = sd_samplers.sample_to_image(samples, i) - # Index that accounts for both batch size and batch count. - ind = i + self.batch_size*n - images.save_image(img, self.outpath_samples, "", self.all_seeds[ind], self.all_prompts[ind], opts.samples_format, suffix=f"-before-highres-fix") - shared.state.nextjob() self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) @@ -844,8 +846,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask) - - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) @@ -856,4 +857,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): del x devices.torch_gc() - return samples \ No newline at end of file + return samples diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index d7fa89a0..c7c414ef 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -96,6 +96,7 @@ def single_sample_to_image(sample): def sample_to_image(samples, index=0): return single_sample_to_image(samples[index]) + def samples_to_image_grid(samples): return images.image_grid([single_sample_to_image(sample) for sample in samples]) diff --git a/modules/shared.py b/modules/shared.py index ce991424..01f47e38 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -256,6 +256,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."), "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."), + "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"), @@ -322,7 +323,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), - "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py index 88abc093..964b75c7 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -166,8 +166,7 @@ class Script(scripts.Script): if override_strength: p.denoising_strength = 1.0 - - def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): + def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): lat = (p.init_latent.cpu().numpy() * 10).astype(int) same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \ -- cgit v1.2.3