From 9fd457e21d6c809a69a1318f03d75f7b3e09b865 Mon Sep 17 00:00:00 2001 From: camenduru <54370274+camenduru@users.noreply.github.com> Date: Thu, 15 Dec 2022 21:57:48 +0300 Subject: allow_credentials and allow_headers for api from https://fastapi.tiangolo.com/tutorial/cors/ --- webui.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'webui.py') diff --git a/webui.py b/webui.py index c2d0c6be..13a4d14a 100644 --- a/webui.py +++ b/webui.py @@ -90,11 +90,11 @@ def initialize(): def setup_cors(app): if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex: - app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*']) + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) elif cmd_opts.cors_allow_origins: - app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*']) + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'], allow_credentials=True, allow_headers=['*']) elif cmd_opts.cors_allow_origins_regex: - app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*']) + app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'], allow_credentials=True, allow_headers=['*']) def create_api(app): -- cgit v1.2.3 From 2d5a5076bb2a0c05cc27d75a1bcadab7f32a46d0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 18:38:21 +0300 Subject: Make it so that upscalers are not repeated when restarting UI. 
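Background on the mechanism, as a minimal illustrative sketch (not part of the patch): the webui discovers upscalers through Upscaler.__subclasses__(), and re-executing a class definition, which is what happens when extension scripts are reloaded on a UI restart, registers a brand-new subclass object while the old one stays listed for as long as anything still references it:

    class Upscaler:
        pass

    def load_extension_script():
        # stands in for importing an extension's *_model.py on (re)load;
        # the function and class names here are illustrative only
        class ExtUpscaler(Upscaler):
            pass
        return ExtUpscaler

    first = load_extension_script()
    second = load_extension_script()   # simulated UI restart: the module runs again

    # both ExtUpscaler class objects are now reported, hence duplicated upscalers
    print(Upscaler.__subclasses__())

The patch below therefore snapshots the subclasses present after the initial load (list_builtin_upscalers), blacklists stale non-builtin subclasses before scripts are reloaded (forbid_loaded_nonbuiltin_upscalers), and has load_upscalers() skip the blacklisted classes.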
--- modules/modelloader.py | 20 ++++++++++++++++++++ webui.py | 14 +++++++------- 2 files changed, 27 insertions(+), 7 deletions(-) (limited to 'webui.py') diff --git a/modules/modelloader.py b/modules/modelloader.py index e647f6fa..6a1a7ac8 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -123,6 +123,23 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None): pass +builtin_upscaler_classes = [] +forbidden_upscaler_classes = set() + + +def list_builtin_upscalers(): + load_upscalers() + + builtin_upscaler_classes.clear() + builtin_upscaler_classes.extend(Upscaler.__subclasses__()) + + +def forbid_loaded_nonbuiltin_upscalers(): + for cls in Upscaler.__subclasses__(): + if cls not in builtin_upscaler_classes: + forbidden_upscaler_classes.add(cls) + + def load_upscalers(): # We can only do this 'magic' method to dynamically load upscalers if they are referenced, # so we'll try to import any _model.py files before looking in __subclasses__ @@ -139,6 +156,9 @@ def load_upscalers(): datas = [] commandline_options = vars(shared.cmd_opts) for cls in Upscaler.__subclasses__(): + if cls in forbidden_upscaler_classes: + continue + name = cls.__name__ cmd_name = f"{name.lower().replace('upscaler', '')}_models_path" scaler = cls(commandline_options.get(cmd_name, None)) diff --git a/webui.py b/webui.py index 3aee8792..c7d55a97 100644 --- a/webui.py +++ b/webui.py @@ -1,4 +1,5 @@ import os +import sys import threading import time import importlib @@ -55,8 +56,8 @@ def initialize(): gfpgan.setup_model(cmd_opts.gfpgan_models_path) shared.face_restorers.append(modules.face_restoration.FaceRestoration()) + modelloader.list_builtin_upscalers() modules.scripts.load_scripts() - modelloader.load_upscalers() modules.sd_vae.refresh_vae_list() @@ -169,23 +170,22 @@ def webui(): modules.script_callbacks.app_started_callback(shared.demo, app) wait_on_server(shared.demo) + print('Restarting UI...') sd_samplers.set_samplers() - print('Reloading extensions') extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) - print('Reloading custom scripts') + modelloader.forbid_loaded_nonbuiltin_upscalers() modules.scripts.reload_scripts() modelloader.load_upscalers() - print('Reloading modules: modules.ui') - importlib.reload(modules.ui) - print('Refreshing Model List') + for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: + importlib.reload(module) + modules.sd_models.list_models() - print('Restarting Gradio') if __name__ == "__main__": -- cgit v1.2.3 From 02d7abf5141431b9a3a8a189bb3136c71abd5e79 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 12:35:07 +0300 Subject: helpful error message when trying to load 2.0 without config failing to load model weights from settings won't break generation for currently loaded model anymore --- modules/errors.py | 25 +++++++++++++++++++++++-- modules/sd_models.py | 26 ++++++++++++++++++-------- modules/shared.py | 9 +++++++-- webui.py | 12 ++++++++++-- 4 files changed, 58 insertions(+), 14 deletions(-) (limited to 'webui.py') diff --git a/modules/errors.py b/modules/errors.py index 372dc51a..a668c014 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -2,9 +2,30 @@ import sys import traceback +def print_error_explanation(message): + lines = message.strip().split("\n") + max_len = max([len(x) for x in lines]) + + print('=' * max_len, file=sys.stderr) + for line in lines: + print(line, file=sys.stderr) + print('=' * max_len, 
file=sys.stderr) + + +def display(e: Exception, task): + print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + message = str(e) + if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message: + print_error_explanation(""" +The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its config file. +See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this. + """) + + def run(code, task): try: code() except Exception as e: - print(f"{task}: {type(e).__name__}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + display(e, task) diff --git a/modules/sd_models.py b/modules/sd_models.py index b98b05fc..6846b74a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -278,6 +278,7 @@ def enable_midas_autodownload(): midas.api.load_model = load_model_wrapper + def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -312,6 +313,7 @@ def load_model(checkpoint_info=None): sd_config.model.params.unet_config.params.use_fp16 = False sd_model = instantiate_from_config(sd_config.model) + load_model_weights(sd_model, checkpoint_info) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: @@ -336,10 +338,12 @@ def load_model(checkpoint_info=None): def reload_model_weights(sd_model=None, info=None): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() - + if not sd_model: sd_model = shared.sd_model + current_checkpoint_info = sd_model.sd_checkpoint_info + if sd_model.sd_model_checkpoint == checkpoint_info.filename: return @@ -356,13 +360,19 @@ def reload_model_weights(sd_model=None, info=None): sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, checkpoint_info) - - sd_hijack.model_hijack.hijack(sd_model) - script_callbacks.model_loaded_callback(sd_model) - - if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: - sd_model.to(devices.device) + try: + load_model_weights(sd_model, checkpoint_info) + except Exception as e: + print("Failed to load checkpoint, restoring previous") + load_model_weights(sd_model, current_checkpoint_info) + raise + finally: + sd_hijack.model_hijack.hijack(sd_model) + script_callbacks.model_loaded_callback(sd_model) + + if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + sd_model.to(devices.device) print("Weights loaded.") + return sd_model diff --git a/modules/shared.py b/modules/shared.py index 23657a93..7588c47b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.interrogate import modules.memmon import modules.styles import modules.devices as devices -from modules import localization, sd_vae, extensions, script_loading +from modules import localization, sd_vae, extensions, script_loading, errors from modules.paths import models_path, script_path, sd_path @@ -494,7 +494,12 @@ class Options: return False if self.data_labels[key].onchange is not None: - self.data_labels[key].onchange() + try: + self.data_labels[key].onchange() + except Exception as e: + errors.display(e, f"changing setting {key} to {value}") + setattr(self, key, oldval) + return False return True diff --git a/webui.py b/webui.py index c7d55a97..13375e71 100644 --- a/webui.py +++ b/webui.py @@ -9,7 +9,7 @@ from fastapi import FastAPI from
fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware -from modules import import_hook +from modules import import_hook, errors from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path @@ -61,7 +61,15 @@ def initialize(): modelloader.load_upscalers() modules.sd_vae.refresh_vae_list() - modules.sd_models.load_model() + + try: + modules.sd_models.load_model() + except Exception as e: + errors.display(e, "loading stable diffusion model") + print("", file=sys.stderr) + print("Stable diffusion model failed to load, exiting", file=sys.stderr) + exit(1) + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) -- cgit v1.2.3 From 8111b5569d07c7ac3b695e28171aede728b4ae56 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 3 Jan 2023 20:43:05 -0500 Subject: Add support for PyTorch nightly and local builds --- modules/devices.py | 28 +++++++++++++++++++++++----- webui.py | 7 ++++++- 2 files changed, 29 insertions(+), 6 deletions(-) (limited to 'webui.py') diff --git a/modules/devices.py b/modules/devices.py index 800510b7..caeb0276 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -133,8 +133,26 @@ def numpy_fix(self, *args, **kwargs): return orig_tensor_numpy(self, *args, **kwargs) -# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working -if has_mps() and version.parse(torch.__version__) < version.parse("1.13"): - torch.Tensor.to = tensor_to_fix - torch.nn.functional.layer_norm = layer_norm_fix - torch.Tensor.numpy = numpy_fix +# MPS workaround for https://github.com/pytorch/pytorch/issues/89784 +orig_cumsum = torch.cumsum +orig_Tensor_cumsum = torch.Tensor.cumsum +def cumsum_fix(input, cumsum_func, *args, **kwargs): + if input.device.type == 'mps': + output_dtype = kwargs.get('dtype', input.dtype) + if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]): + return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) + return cumsum_func(input, *args, **kwargs) + + +if has_mps(): + if version.parse(torch.__version__) < version.parse("1.13"): + # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working + torch.Tensor.to = tensor_to_fix + torch.nn.functional.layer_norm = layer_norm_fix + torch.Tensor.numpy = numpy_fix + elif version.parse(torch.__version__) > version.parse("1.13.1"): + if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)): + torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) ) + torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) + orig_narrow = torch.narrow + torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() ) diff --git a/webui.py b/webui.py index 13375e71..ddfaea95 100644 --- a/webui.py +++ b/webui.py @@ -4,7 +4,7 @@ import threading import time import importlib import signal -import threading +import re from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from 
fastapi.middleware.gzip import GZipMiddleware @@ -13,6 +13,11 @@ from modules import import_hook, errors from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path +import torch +# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors +if ".dev" in torch.__version__ or "+git" in torch.__version__: + torch.__version__ = re.search(r'[\d.]+', torch.__version__).group(0) + from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir import modules.codeformer_model as codeformer import modules.extras -- cgit v1.2.3 From 65ed4421e609dda3112f236c13e4db14caa71364 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 6 Jan 2023 13:55:50 +0300 Subject: add callback for when the script is unloaded --- modules/script_callbacks.py | 18 +++++++++++++++++- webui.py | 2 ++ 2 files changed, 19 insertions(+), 1 deletion(-) (limited to 'webui.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index de69fd9f..608c5300 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -71,6 +71,7 @@ callback_map = dict( callbacks_before_component=[], callbacks_after_component=[], callbacks_image_grid=[], + callbacks_script_unloaded=[], ) @@ -171,6 +172,14 @@ def image_grid_callback(params: ImageGridLoopParams): report_exception(c, 'image_grid') +def script_unloaded_callback(): + for c in reversed(callback_map['callbacks_script_unloaded']): + try: + c.callback() + except Exception: + report_exception(c, 'script_unloaded') + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if len(stack) > 0 else 'unknown file' @@ -202,7 +211,7 @@ def on_app_started(callback): def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is - passed as an argument""" + passed as an argument; this function is also called when the script is reloaded. """ add_callback(callback_map['callbacks_model_loaded'], callback) @@ -279,3 +288,10 @@ def on_image_grid(callback): - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified. """ add_callback(callback_map['callbacks_image_grid'], callback) + + +def on_script_unloaded(callback): + """register a function to be called before the script is unloaded. 
Any hooks/hijacks/monkeying about that + the script did should be reverted here""" + + add_callback(callback_map['callbacks_script_unloaded'], callback) diff --git a/webui.py b/webui.py index ff6eb6eb..733a06b5 100644 --- a/webui.py +++ b/webui.py @@ -187,12 +187,14 @@ def webui(): sd_samplers.set_samplers() + modules.script_callbacks.script_unloaded_callback() extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) modelloader.forbid_loaded_nonbuiltin_upscalers() modules.scripts.reload_scripts() + modules.script_callbacks.model_loaded_callback(shared.sd_model) modelloader.load_upscalers() for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: -- cgit v1.2.3 From 5e6566324bba20554bcc04f3dda798e560397f38 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 6 Jan 2023 07:06:26 -0500 Subject: Always end version number with a digit --- webui.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'webui.py') diff --git a/webui.py b/webui.py index 733a06b5..8737e593 100644 --- a/webui.py +++ b/webui.py @@ -16,7 +16,7 @@ from modules.paths import script_path import torch # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors if ".dev" in torch.__version__ or "+git" in torch.__version__: - torch.__version__ = re.search(r'[\d.]+', torch.__version__).group(0) + torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir import modules.codeformer_model as codeformer -- cgit v1.2.3 From 1fbb6f9ebe48326a3b12ecf611105dbc4a46891e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 23:35:40 +0300 Subject: make a dropdown for prompt template selection --- modules/hypernetworks/hypernetwork.py | 7 ++++-- modules/shared.py | 1 + modules/textual_inversion/textual_inversion.py | 35 ++++++++++++++++++++------ modules/ui.py | 11 ++++++-- webui.py | 3 +++ 5 files changed, 45 insertions(+), 12 deletions(-) (limited to 'webui.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 32c67ccc..ea3f1db9 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -24,6 +24,7 @@ from statistics import stdev, mean optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"} + class HypernetworkModule(torch.nn.Module): multiplier = 1.0 activation_dict = { @@ -403,13 +404,15 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, 
preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images save_hypernetwork_every = save_hypernetwork_every or 0 create_image_every = create_image_every or 0 - textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") + template_file = textual_inversion.textual_inversion_templates.get(template_filename, None) + textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork") + template_file = template_file.path path = shared.hypernetworks.get(hypernetwork_name, None) shared.loaded_hypernetwork = Hypernetwork() diff --git a/modules/shared.py b/modules/shared.py index a1e10201..aa37c8ce 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -33,6 +33,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--textual-inversion-templates-dir", type=str, default=os.path.join(script_path, 'textual_inversion_templates'), help="directory with textual inversion templates") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 14be2c96..5420903f 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -2,6 +2,7 @@ import os import sys import traceback import inspect +from collections import namedtuple import torch import tqdm @@ -15,12 +16,26 @@ from modules import shared, devices, sd_hijack, processing, sd_models, images, s import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler -from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64, - insert_image_data_embed, extract_image_data_embed, - caption_image_overlay) +from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay from modules.textual_inversion.logging import save_settings_to_file +TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"]) +textual_inversion_templates = {} + + +def list_textual_inversion_templates(): + textual_inversion_templates.clear() + + for root, dirs, fns in 
os.walk(shared.cmd_opts.textual_inversion_templates_dir): + for fn in fns: + path = os.path.join(root, fn) + + textual_inversion_templates[fn] = TextualInversionTemplate(fn, path) + + return textual_inversion_templates + + class Embedding: def __init__(self, vec, name, step=None): self.vec = vec @@ -274,7 +289,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): }) -def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"): +def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"): assert model_name, f"{name} not selected" assert learn_rate, "Learning rate is empty or 0" assert isinstance(batch_size, int), "Batch size must be integer" @@ -284,8 +299,9 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat assert data_root, "Dataset directory is empty" assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" - assert template_file, "Prompt template file is empty" - assert os.path.isfile(template_file), "Prompt template file doesn't exist" + assert template_filename, "Prompt template file not selected" + assert template_file, f"Prompt template file {template_filename} not found" + assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist" assert steps, "Max steps is empty or 0" assert isinstance(steps, int), "Max steps must be integer" assert steps > 0, "Max steps must be positive" @@ -296,10 +312,13 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + +def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 - validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") + template_file = textual_inversion_templates.get(template_filename, None) + validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding") + template_file = template_file.path shared.state.job = "train-embedding" 
shared.state.textinfo = "Initializing textual inversion training..." diff --git a/modules/ui.py b/modules/ui.py index ddfe1b1a..b6079aec 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -37,7 +37,7 @@ from modules import prompt_parser from modules.images import save_image from modules.sd_hijack import model_hijack from modules.sd_samplers import samplers, samplers_for_img2img -import modules.textual_inversion.ui +from modules.textual_inversion import textual_inversion import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text @@ -1322,6 +1322,9 @@ def create_ui(): outputs=[process_focal_crop_row], ) + def get_textual_inversion_template_names(): + return sorted([x for x in textual_inversion.textual_inversion_templates]) + with gr.Tab(label="Train"): gr.HTML(value="
Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images [wiki]
") with FormRow(): @@ -1345,7 +1348,11 @@ def create_ui(): dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion", elem_id="train_log_directory") - template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file") + + with FormRow(): + template_file = gr.Dropdown(label='Prompt template', value="style_filewords.txt", elem_id="train_template_file", choices=get_textual_inversion_template_names()) + create_refresh_button(template_file, textual_inversion.list_textual_inversion_templates, lambda: {"choices": get_textual_inversion_template_names()}, "refrsh_train_template_file") + training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") diff --git a/webui.py b/webui.py index 8737e593..47d372c7 100644 --- a/webui.py +++ b/webui.py @@ -33,6 +33,7 @@ import modules.sd_models import modules.sd_vae import modules.txt2img import modules.script_callbacks +import modules.textual_inversion.textual_inversion import modules.ui from modules import modelloader @@ -67,6 +68,8 @@ def initialize(): modules.sd_vae.refresh_vae_list() + modules.textual_inversion.textual_inversion.list_textual_inversion_templates() + try: modules.sd_models.load_model() except Exception as e: -- cgit v1.2.3 From a95f1353089bdeaccd7c266b40cdd79efedfe632 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 09:56:59 +0300 Subject: change hash to sha256 --- .gitignore | 1 + modules/api/api.py | 2 +- modules/api/models.py | 3 +- modules/hashes.py | 72 +++++++++++++++ modules/hypernetworks/hypernetwork.py | 4 +- modules/sd_models.py | 116 ++++++++++++++++--------- modules/shared.py | 2 +- modules/textual_inversion/textual_inversion.py | 6 +- webui.py | 2 + 9 files changed, 158 insertions(+), 50 deletions(-) create mode 100644 modules/hashes.py (limited to 'webui.py') diff --git a/.gitignore b/.gitignore index 21fa26a7..0b1d17ca 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,4 @@ notification.mp3 /extensions /test/stdout.txt /test/stderr.txt +/cache.json diff --git a/modules/api/api.py b/modules/api/api.py index 5767ba90..9814bbc2 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -371,7 +371,7 @@ class Api: return upscalers def get_sd_models(self): - return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] + return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config(x)} for x in checkpoints_list.values()] def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] diff --git a/modules/api/models.py b/modules/api/models.py index c78095ca..1eb1fcf1 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -224,7 +224,8 @@ class UpscalerItem(BaseModel): class SDModelItem(BaseModel): title: str = Field(title="Title") model_name: str = 
Field(title="Model Name") - hash: str = Field(title="Hash") + hash: Optional[str] = Field(title="Short hash") + sha256: Optional[str] = Field(title="sha256 hash") filename: str = Field(title="Filename") config: str = Field(title="Config file") diff --git a/modules/hashes.py b/modules/hashes.py new file mode 100644 index 00000000..ebfbd90c --- /dev/null +++ b/modules/hashes.py @@ -0,0 +1,72 @@ +import hashlib +import json +import os.path + +import filelock + + +cache_filename = "cache.json" +cache_data = None + + +def dump_cache(): + with filelock.FileLock(cache_filename+".lock"): + with open(cache_filename, "w", encoding="utf8") as file: + json.dump(cache_data, file, indent=4) + + +def cache(subsection): + global cache_data + + if cache_data is None: + with filelock.FileLock(cache_filename+".lock"): + if not os.path.isfile(cache_filename): + cache_data = {} + else: + with open(cache_filename, "r", encoding="utf8") as file: + cache_data = json.load(file) + + s = cache_data.get(subsection, {}) + cache_data[subsection] = s + + return s + + +def calculate_sha256(filename): + hash_sha256 = hashlib.sha256() + + with open(filename, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + + +def sha256(filename, title): + hashes = cache("hashes") + ondisk_mtime = os.path.getmtime(filename) + + if title in hashes: + cached_sha256 = hashes[title].get("sha256", None) + cached_mtime = hashes[title].get("mtime", 0) + + if ondisk_mtime <= cached_mtime and cached_sha256 is not None: + return cached_sha256 + + print(f"Calculating sha256 for {filename}: ", end='') + sha256_value = calculate_sha256(filename) + print(f"{sha256_value}") + + hashes[title] = { + "mtime": ondisk_mtime, + "sha256": sha256_value, + } + + dump_cache() + + return sha256_value + + + + + diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 83cbb4f0..9b5f2e79 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -509,7 +509,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if shared.opts.save_training_settings_to_txt: saved_params = dict( - model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), + model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]} ) logging.save_settings_to_file(log_directory, {**saved_params, **locals()}) @@ -737,7 +737,7 @@ def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None try: - hypernetwork.sd_checkpoint = checkpoint.hash + hypernetwork.sd_checkpoint = checkpoint.shorthash hypernetwork.sd_checkpoint_name = checkpoint.model_name hypernetwork.name = hypernetwork_name hypernetwork.save(filename) diff --git a/modules/sd_models.py b/modules/sd_models.py index c466f273..7babb9ae 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,17 +14,56 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors +from modules import shared, 
modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) -CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) checkpoints_list = {} +checkpoint_alisases = {} checkpoints_loaded = collections.OrderedDict() + +class CheckpointInfo: + def __init__(self, filename): + self.filename = filename + abspath = os.path.abspath(filename) + + if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): + name = abspath.replace(shared.cmd_opts.ckpt_dir, '') + elif abspath.startswith(model_path): + name = abspath.replace(model_path, '') + else: + name = os.path.basename(filename) + + if name.startswith("\\") or name.startswith("/"): + name = name[1:] + + self.title = name + self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] + self.hash = model_hash(filename) + self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + self.shorthash = None + self.sha256 = None + + def register(self): + checkpoints_list[self.title] = self + for id in self.ids: + checkpoint_alisases[id] = self + + def calculate_shorthash(self): + self.sha256 = hashes.sha256(self.filename, self.title) + self.shorthash = self.sha256[0:10] + + if self.shorthash not in self.ids: + self.ids += [self.shorthash, self.sha256] + self.register() + + return self.shorthash + + try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. @@ -43,10 +82,14 @@ def setup_model(): enable_midas_autodownload() -def checkpoint_tiles(): - convert = lambda name: int(name) if name.isdigit() else name.lower() - alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)] - return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key) +def checkpoint_tiles(): + def convert(name): + return int(name) if name.isdigit() else name.lower() + + def alphanumeric_key(key): + return [convert(c) for c in re.split('([0-9]+)', key)] + + return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) def find_checkpoint_config(info): @@ -62,48 +105,38 @@ def find_checkpoint_config(info): def list_models(): checkpoints_list.clear() + checkpoint_alisases.clear() model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"]) - def modeltitle(path, shorthash): - abspath = os.path.abspath(path) - - if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): - name = abspath.replace(shared.cmd_opts.ckpt_dir, '') - elif abspath.startswith(model_path): - name = abspath.replace(model_path, '') - else: - name = os.path.basename(path) - - if name.startswith("\\") or name.startswith("/"): - name = name[1:] - - shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] - - return f'{name} [{shorthash}]', shortname - cmd_ckpt = shared.cmd_opts.ckpt if os.path.exists(cmd_ckpt): - h = model_hash(cmd_ckpt) - title, short_model_name = modeltitle(cmd_ckpt, h) - checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) - shared.opts.data['sd_model_checkpoint'] = title + checkpoint_info = CheckpointInfo(cmd_ckpt) + 
checkpoint_info.register() + + shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) + for filename in model_list: - h = model_hash(filename) - title, short_model_name = modeltitle(filename, h) + checkpoint_info = CheckpointInfo(filename) + checkpoint_info.register() + - checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name) +def get_closet_checkpoint_match(search_string): + checkpoint_info = checkpoint_alisases.get(search_string, None) + if checkpoint_info is not None: + return checkpoint_info + found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title)) + if found: + return found[0] -def get_closet_checkpoint_match(searchString): - applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title)) - if len(applicable) > 0: - return applicable[0] return None def model_hash(filename): + """old hash that only looks at a small part of the file and is prone to collisions""" + try: with open(filename, "rb") as file: import hashlib @@ -119,7 +152,7 @@ def model_hash(filename): def select_checkpoint(): model_checkpoint = shared.opts.sd_model_checkpoint - checkpoint_info = checkpoints_list.get(model_checkpoint, None) + checkpoint_info = checkpoint_alisases.get(model_checkpoint, None) if checkpoint_info is not None: return checkpoint_info @@ -189,9 +222,8 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None return sd -def load_model_weights(model, checkpoint_info, vae_file="auto"): - checkpoint_file = checkpoint_info.filename - sd_model_hash = checkpoint_info.hash +def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): + sd_model_hash = checkpoint_info.calculate_shorthash() cache_enabled = shared.opts.sd_checkpoint_cache > 0 @@ -201,9 +233,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.load_state_dict(checkpoints_loaded[checkpoint_info]) else: # load from file - print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") + print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") - sd = read_state_dict(checkpoint_file) + sd = read_state_dict(checkpoint_info.filename) model.load_state_dict(sd, strict=False) del sd @@ -235,14 +267,14 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoints_loaded.popitem(last=False) # LRU model.sd_model_hash = sd_model_hash - model.sd_model_checkpoint = checkpoint_file + model.sd_model_checkpoint = checkpoint_info.filename model.sd_checkpoint_info = checkpoint_info model.logvar = model.logvar.to(devices.device) # fix for training sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() - vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) + vae_file = sd_vae.resolve_vae(checkpoint_info.filename, vae_file=vae_file) sd_vae.load_vae(model, vae_file) diff --git a/modules/shared.py b/modules/shared.py index b90ded52..d74c069d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -428,7 +428,7 @@ options_templates.update(options_section(('ui', "User interface"), { "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), -
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"), + "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), "disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 6939efcc..63935878 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -407,7 +407,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize) if shared.opts.save_training_settings_to_txt: - save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) + save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) latent_sampling_method = ds.latent_sampling_method @@ -584,7 +584,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ checkpoint = sd_models.select_checkpoint() footer_left = checkpoint.model_name - footer_mid = '[{}]'.format(checkpoint.hash) + footer_mid = '[{}]'.format(checkpoint.shorthash) footer_right = '{}v {}s'.format(vectorSize, steps_done) captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right) @@ -626,7 +626,7 @@ def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, r old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None try: - embedding.sd_checkpoint = checkpoint.hash + embedding.sd_checkpoint = checkpoint.shorthash embedding.sd_checkpoint_name = checkpoint.model_name if remove_cached_checksum: embedding.cached_checksum = None diff --git a/webui.py b/webui.py index 47d372c7..1fff80da 100644 --- a/webui.py +++ b/webui.py @@ -78,6 +78,8 @@ def initialize(): print("Stable diffusion model failed to load, exiting", file=sys.stderr) exit(1) + shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) -- cgit v1.2.3 From 6eb72fd13f34d94d5459290dd1a0bf0e9ddeda82 Mon Sep 17 00:00:00 2001 
From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 13:38:10 +0300 Subject: bump gradio to 3.16.1 --- modules/ui.py | 13 ++++++----- requirements.txt | 2 +- requirements_versions.txt | 2 +- style.css | 57 ++++++++++++++++++++++++++++++++--------------- webui.py | 3 +-- 5 files changed, 49 insertions(+), 28 deletions(-) (limited to 'webui.py') diff --git a/modules/ui.py b/modules/ui.py index e86a624b..202e84e5 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -605,7 +605,7 @@ def create_ui(): setup_progressbar(progressbar, txt2img_preview, 'txt2img') with gr.Row().style(equal_height=False): - with gr.Column(variant='panel', elem_id="txt2img_settings"): + with gr.Column(variant='compact', elem_id="txt2img_settings"): for category in ordered_ui_categories(): if category == "sampler": steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") @@ -794,7 +794,7 @@ def create_ui(): setup_progressbar(progressbar, img2img_preview, 'img2img') with FormRow().style(equal_height=False): - with gr.Column(variant='panel', elem_id="img2img_settings"): + with gr.Column(variant='compact', elem_id="img2img_settings"): with gr.Tabs(elem_id="mode_img2img"): with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool="editor", image_mode="RGBA").style(height=480) @@ -1026,7 +1026,7 @@ def create_ui(): with gr.Blocks(analytics_enabled=False) as extras_interface: with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): + with gr.Column(variant='compact'): with gr.Tabs(elem_id="mode_extras"): with gr.TabItem('Single Image', elem_id="extras_single_tab"): extras_image = gr.Image(label="Source", source="upload", interactive=True, type="pil", elem_id="extras_image") @@ -1127,8 +1127,8 @@ def create_ui(): with gr.Blocks(analytics_enabled=False) as modelmerger_interface: with gr.Row().style(equal_height=False): - with gr.Column(variant='panel'): - gr.HTML(value="
A merger of the two checkpoints will be generated in your checkpoint directory.
") + with gr.Column(variant='compact'): + gr.HTML(value="
A merger of the two checkpoints will be generated in your checkpoint directory.
") with FormRow(): primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") @@ -1150,7 +1150,8 @@ def create_ui(): config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method") - modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') + with gr.Row(): + modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') with gr.Column(variant='panel'): submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) diff --git a/requirements.txt b/requirements.txt index e1dbf8e5..6cdea781 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,7 +5,7 @@ fairscale==0.4.4 fonts font-roboto gfpgan -gradio==3.15.0 +gradio==3.16.1 invisible-watermark numpy omegaconf diff --git a/requirements_versions.txt b/requirements_versions.txt index d2899292..cc06d2b4 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -3,7 +3,7 @@ transformers==4.19.2 accelerate==0.12.0 basicsr==1.4.2 gfpgan==1.3.8 -gradio==3.15.0 +gradio==3.16.1 numpy==1.23.3 Pillow==9.4.0 realesrgan==0.3.0 diff --git a/style.css b/style.css index ffd6307f..14b15191 100644 --- a/style.css +++ b/style.css @@ -20,7 +20,7 @@ padding-right: 0.25em; margin: 0.1em 0; opacity: 0%; - cursor: default; + cursor: default; } .output-html p {margin: 0 0.5em;} @@ -221,7 +221,10 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s .dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{ background-color: rgb(31, 41, 55); - box-shadow: 6px 0 6px 0px rgb(31, 41, 55), -6px 0 6px 0px rgb(31, 41, 55); + box-shadow: none; + border: 1px solid rgba(128, 128, 128, 0.1); + border-radius: 6px; + padding: 0.1em 0.5em; } #txt2img_column_batch, #img2img_column_batch{ @@ -371,7 +374,7 @@ input[type="range"]{ grid-area: tile; } -.modalClose, +.modalClose, .modalZoom, .modalTileImage { color: white; @@ -509,29 +512,20 @@ input[type="range"]{ } #quicksettings > div{ - border: none; - background: none; - flex: unset; - gap: 1em; -} - -#quicksettings > div > div{ - max-width: 32em; + max-width: 24em; min-width: 24em; padding: 0; + border: none; + box-shadow: none; + background: none; } -#quicksettings > div > div > div > div > label > span { +#quicksettings > div > div > div > label > span { position: relative; margin-right: 9em; margin-bottom: -1em; } -#quicksettings > div > div > label > span { - position: relative; - margin-bottom: -1em; -} - canvas[key="mask"] { z-index: 12 !important; filter: invert(); @@ -666,7 +660,10 @@ footer { font-weight: bold; } -#txt2img_checkboxes > div > div{ +#txt2img_checkboxes, #img2img_checkboxes{ + margin-bottom: 0.5em; +} +#txt2img_checkboxes > div > div, #img2img_checkboxes > div > div{ flex: 0; white-space: nowrap; min-width: auto; @@ -676,6 +673,30 @@ footer { opacity: 0.5; } +.gr-compact { + border: none; + padding-top: 1em; +} + +.dark .gr-compact{ + background-color: rgb(31 41 55 / var(--tw-bg-opacity)); + margin-left: 0.8em; +} + +.gr-compact > *{ + margin-top: 0.5em !important; +} + +.gr-compact .gr-block, .gr-compact .gr-form{ + border: none; + box-shadow: none; +} + +.gr-compact .gr-box{ + border-radius: .5rem !important; + border-width: 1px !important; +} + /* The following handles localization for right-to-left (RTL) languages like Arabic. 
The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running diff --git a/webui.py b/webui.py index 1fff80da..84159515 100644 --- a/webui.py +++ b/webui.py @@ -157,7 +157,7 @@ def webui(): shared.demo = modules.ui.create_ui() - app, local_url, share_url = shared.demo.queue(default_enabled=False).launch( + app, local_url, share_url = shared.demo.launch( share=cmd_opts.share, server_name=server_name, server_port=cmd_opts.port, @@ -185,7 +185,6 @@ def webui(): create_api(app) modules.script_callbacks.app_started_callback(shared.demo, app) - modules.script_callbacks.app_started_callback(shared.demo, app) wait_on_server(shared.demo) print('Restarting UI...') -- cgit v1.2.3 From d8b90ac121cbf0c18b1dc9d56a5e1d14ca51e74e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 15 Jan 2023 18:50:56 +0300 Subject: big rework of progressbar/preview system to allow multiple users to prompts at the same time and do not get previews of each other --- javascript/progressbar.js | 249 ++++++++++++++++--------- javascript/textualInversion.js | 13 +- javascript/ui.js | 33 +++- modules/call_queue.py | 19 +- modules/hypernetworks/hypernetwork.py | 6 +- modules/img2img.py | 2 +- modules/progress.py | 96 ++++++++++ modules/sd_samplers.py | 2 +- modules/shared.py | 16 +- modules/textual_inversion/preprocess.py | 2 +- modules/textual_inversion/textual_inversion.py | 6 +- modules/txt2img.py | 2 +- modules/ui.py | 41 ++-- modules/ui_progress.py | 101 ---------- style.css | 74 +++++--- webui.py | 3 + 16 files changed, 390 insertions(+), 275 deletions(-) create mode 100644 modules/progress.py delete mode 100644 modules/ui_progress.py (limited to 'webui.py') diff --git a/javascript/progressbar.js b/javascript/progressbar.js index d6323ed9..b7524ef7 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -1,82 +1,25 @@ // code related to showing and updating progressbar shown as the image is being made -global_progressbars = {} -galleries = {} -galleryObservers = {} - -// this tracks launches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running -timeoutIds = {} -function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){ - // gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with same id - // every time you use gr.HTML(elem_id='xxx'), so we handle this here - var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar) - var progressbarParent - if(progressbar){ - progressbarParent = gradioApp().querySelector("#"+id_progressbar) - } else{ - progressbar = gradioApp().getElementById(id_progressbar) - progressbarParent = null - } - var skip = id_skip ? 
gradioApp().getElementById(id_skip) : null - var interrupt = gradioApp().getElementById(id_interrupt) - - if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){ - if(progressbar.innerText){ - let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion'; - if(document.title != newtitle){ - document.title = newtitle; - } - }else{ - let newtitle = 'Stable Diffusion' - if(document.title != newtitle){ - document.title = newtitle; - } - } - } - - if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){ - global_progressbars[id_progressbar] = progressbar - - var mutationObserver = new MutationObserver(function(m){ - if(timeoutIds[id_part]) return; - - preview = gradioApp().getElementById(id_preview) - gallery = gradioApp().getElementById(id_gallery) +galleries = {} +storedGallerySelections = {} +galleryObservers = {} - if(preview != null && gallery != null){ - preview.style.width = gallery.clientWidth + "px" - preview.style.height = gallery.clientHeight + "px" - if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px" +function rememberGallerySelection(id_gallery){ + storedGallerySelections[id_gallery] = getGallerySelectedIndex(id_gallery) +} - //only watch gallery if there is a generation process going on - check_gallery(id_gallery); +function getGallerySelectedIndex(id_gallery){ + let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item') + let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2') - var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; - if(progressDiv){ - timeoutIds[id_part] = window.setTimeout(function() { - timeoutIds[id_part] = null - requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) - }, 500) - } else{ - if (skip) { - skip.style.display = "none" - } - interrupt.style.display = "none" + let currentlySelectedIndex = -1 + galleryButtons.forEach(function(v, i){ if(v==galleryBtnSelected) { currentlySelectedIndex = i } }) - //disconnect observer once generation finished, so user can close selected image if they want - if (galleryObservers[id_gallery]) { - galleryObservers[id_gallery].disconnect(); - galleries[id_gallery] = null; - } - } - } - - }); - mutationObserver.observe( progressbar, { childList:true, subtree:true }) - } + return currentlySelectedIndex } +// this is a workaround for https://github.com/gradio-app/gradio/issues/2984 function check_gallery(id_gallery){ let gallery = gradioApp().getElementById(id_gallery) // if gallery has no change, no need to setting up observer again. 
@@ -85,10 +28,16 @@ function check_gallery(id_gallery){ if(galleryObservers[id_gallery]){ galleryObservers[id_gallery].disconnect(); } - let prevSelectedIndex = selected_gallery_index(); + + storedGallerySelections[id_gallery] = -1 + galleryObservers[id_gallery] = new MutationObserver(function (){ let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item') let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2') + let currentlySelectedIndex = getGallerySelectedIndex(id_gallery) + prevSelectedIndex = storedGallerySelections[id_gallery] + storedGallerySelections[id_gallery] = -1 + if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) { // automatically re-open previously selected index (if exists) activeElement = gradioApp().activeElement; @@ -120,30 +69,150 @@ function check_gallery(id_gallery){ } onUiUpdate(function(){ - check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') - check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') - check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery') + check_gallery('txt2img_gallery') + check_gallery('img2img_gallery') }) -function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){ - btn = gradioApp().getElementById(id_part+"_check_progress"); - if(btn==null) return; - - btn.click(); - var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; - var skip = id_skip ? gradioApp().getElementById(id_skip) : null - var interrupt = gradioApp().getElementById(id_interrupt) - if(progressDiv && interrupt){ - if (skip) { - skip.style.display = "block" +function request(url, data, handler, errorHandler){ + var xhr = new XMLHttpRequest(); + var url = url; + xhr.open("POST", url, true); + xhr.setRequestHeader("Content-Type", "application/json"); + xhr.onreadystatechange = function () { + if (xhr.readyState === 4) { + if (xhr.status === 200) { + var js = JSON.parse(xhr.responseText); + handler(js) + } else{ + errorHandler() + } } - interrupt.style.display = "block" + }; + var js = JSON.stringify(data); + xhr.send(js); +} + +function pad2(x){ + return x<10 ? '0'+x : x +} + +function formatTime(secs){ + if(secs > 3600){ + return pad2(Math.floor(secs/60/60)) + ":" + pad2(Math.floor(secs/60)%60) + ":" + pad2(Math.floor(secs)%60) + } else if(secs > 60){ + return pad2(Math.floor(secs/60)) + ":" + pad2(Math.floor(secs)%60) + } else{ + return Math.floor(secs) + "s" } } -function requestProgress(id_part){ - btn = gradioApp().getElementById(id_part+"_check_progress_initial"); - if(btn==null) return; +function randomId(){ + return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")" +} + +// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and +// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd. 
+// calls onProgress every time there is a progress update +function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress){ + var dateStart = new Date() + var wasEverActive = false + var parentProgressbar = progressbarContainer.parentNode + var parentGallery = gallery.parentNode + + var divProgress = document.createElement('div') + divProgress.className='progressDiv' + var divInner = document.createElement('div') + divInner.className='progress' + + divProgress.appendChild(divInner) + parentProgressbar.insertBefore(divProgress, progressbarContainer) + + var livePreview = document.createElement('div') + livePreview.className='livePreview' + parentGallery.insertBefore(livePreview, gallery) + + var removeProgressBar = function(){ + parentProgressbar.removeChild(divProgress) + parentGallery.removeChild(livePreview) + atEnd() + } + + var fun = function(id_task, id_live_preview){ + request("/internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){ + console.log(res) + + if(res.completed){ + removeProgressBar() + return + } + + var rect = progressbarContainer.getBoundingClientRect() + + if(rect.width){ + divProgress.style.width = rect.width + "px"; + } + + progressText = "" + + divInner.style.width = ((res.progress || 0) * 100.0) + '%' + + if(res.progress > 0){ + progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%' + } + + if(res.eta){ + progressText += " ETA: " + formatTime(res.eta) + } else if(res.textinfo){ + progressText += " " + res.textinfo + } + + divInner.textContent = progressText + + var elapsedFromStart = (new Date() - dateStart) / 1000 + + if(res.active) wasEverActive = true; + + if(! res.active && wasEverActive){ + removeProgressBar() + return + } + + if(elapsedFromStart > 5 && !res.queued && !res.active){ + removeProgressBar() + return + } + + + if(res.live_preview){ + var img = new Image(); + img.onload = function() { + var rect = gallery.getBoundingClientRect() + if(rect.width){ + livePreview.style.width = rect.width + "px" + livePreview.style.height = rect.height + "px" + } + + livePreview.innerHTML = '' + livePreview.appendChild(img) + if(livePreview.childElementCount > 2){ + livePreview.removeChild(livePreview.firstElementChild) + } + } + img.src = res.live_preview; + } + + + if(onProgress){ + onProgress(res) + } + + setTimeout(() => { + fun(id_task, res.id_live_preview); + }, 500) + }, function(){ + removeProgressBar() + }) + } - btn.click(); + fun(id_task, 0) } diff --git a/javascript/textualInversion.js b/javascript/textualInversion.js index 8061be08..0354b860 100644 --- a/javascript/textualInversion.js +++ b/javascript/textualInversion.js @@ -1,8 +1,17 @@ + function start_training_textual_inversion(){ - requestProgress('ti') gradioApp().querySelector('#ti_error').innerHTML='' - return args_to_array(arguments) + var id = randomId() + requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){ + gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo + }) + + var res = args_to_array(arguments) + + res[0] = id + + return res } diff --git a/javascript/ui.js b/javascript/ui.js index f8279124..ecf97cb3 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -126,18 +126,41 @@ function create_submit_args(args){ return res } +function showSubmitButtons(tabname, show){ + gradioApp().getElementById(tabname+'_interrupt').style.display = show ? "none" : "block" + gradioApp().getElementById(tabname+'_skip').style.display = show ? 
"none" : "block" +} + function submit(){ - requestProgress('txt2img') + rememberGallerySelection('txt2img_gallery') + showSubmitButtons('txt2img', false) + + var id = randomId() + requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){ + showSubmitButtons('txt2img', true) + + }) - return create_submit_args(arguments) + var res = create_submit_args(arguments) + + res[0] = id + + return res } function submit_img2img(){ - requestProgress('img2img') + rememberGallerySelection('img2img_gallery') + showSubmitButtons('img2img', false) + + var id = randomId() + requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){ + showSubmitButtons('img2img', true) + }) - res = create_submit_args(arguments) + var res = create_submit_args(arguments) - res[0] = get_tab_index('mode_img2img') + res[0] = id + res[1] = get_tab_index('mode_img2img') return res } diff --git a/modules/call_queue.py b/modules/call_queue.py index 4cd49533..92097c15 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -4,7 +4,7 @@ import threading import traceback import time -from modules import shared +from modules import shared, progress queue_lock = threading.Lock() @@ -22,12 +22,23 @@ def wrap_queued_call(func): def wrap_gradio_gpu_call(func, extra_outputs=None): def f(*args, **kwargs): - shared.state.begin() + # if the first argument is a string that says "task(...)", it is treated as a job id + if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")": + id_task = args[0] + progress.add_task_to_queue(id_task) + else: + id_task = None with queue_lock: - res = func(*args, **kwargs) + shared.state.begin() + progress.start_task(id_task) + + try: + res = func(*args, **kwargs) + finally: + progress.finish_task(id_task) - shared.state.end() + shared.state.end() return res diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3aebefa8..ae6af516 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -453,7 +453,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. 
from modules import images @@ -629,7 +629,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}" pbar.set_description(description) - shared.state.textinfo = description if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: # Before saving, change name to match current checkpoint. hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' @@ -701,7 +700,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, torch.cuda.set_rng_state_all(cuda_rng_state) hypernetwork.train() if image is not None: - shared.state.current_image = image + shared.state.assign_current_image(image) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" diff --git a/modules/img2img.py b/modules/img2img.py index f62783c6..f4a03c57 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): +def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): is_batch = mode == 5 if mode == 0: # img2img diff --git a/modules/progress.py b/modules/progress.py new file mode 100644 index 00000000..3327b883 --- /dev/null +++ b/modules/progress.py @@ -0,0 +1,96 @@ +import base64 +import io +import time + +import gradio as gr +from pydantic import BaseModel, Field + +from modules.shared import opts + +import modules.shared as shared + + +current_task = None +pending_tasks = {} +finished_tasks = [] + + +def start_task(id_task): + global current_task + + current_task = id_task + pending_tasks.pop(id_task, None) + + +def finish_task(id_task): + global current_task + + if current_task == id_task: + current_task = None + + finished_tasks.append(id_task) + if len(finished_tasks) > 16: + finished_tasks.pop(0) + + 
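
The three helpers above, together with add_task_to_queue just below, form the whole server-side task registry: a job id moves from pending_tasks to current_task and finally into finished_tasks. A minimal sketch of how call_queue.py (earlier in this patch) drives them; run_generation and the literal id are illustrative stand-ins, not code from the patch:

```python
# Editor's sketch: the task lifecycle as driven by wrap_gradio_gpu_call.
from modules import progress

id_task = "task(p0sg1x7kkm0b5dq)"        # generated in the browser by randomId()

progress.add_task_to_queue(id_task)      # /internal/progress now reports queued=True
try:
    progress.start_task(id_task)         # becomes the single active task
    run_generation()                     # stand-in for the queued gradio call
finally:
    progress.finish_task(id_task)        # reported completed; only last 16 ids kept
```
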
+def add_task_to_queue(id_job):
+    pending_tasks[id_job] = time.time()
+
+
+class ProgressRequest(BaseModel):
+    id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for")
+    id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of the last received preview image")
+
+
+class ProgressResponse(BaseModel):
+    active: bool = Field(title="Whether the task is being worked on right now")
+    queued: bool = Field(title="Whether the task is in queue")
+    completed: bool = Field(title="Whether the task has already finished")
+    progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1")
+    eta: float = Field(default=None, title="ETA in secs")
+    live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri")
+    id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image")
+    textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.")
+
+
+def setup_progress_api(app):
+    return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse)
+
+
+def progressapi(req: ProgressRequest):
+    active = req.id_task == current_task
+    queued = req.id_task in pending_tasks
+    completed = req.id_task in finished_tasks
+
+    if not active:
+        return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...")
+
+    progress = 0
+
+    if shared.state.job_count > 0:
+        progress += shared.state.job_no / shared.state.job_count
+    if shared.state.sampling_steps > 0:
+        progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
+
+    progress = min(progress, 1)
+
+    elapsed_since_start = time.time() - shared.state.time_start
+    predicted_duration = elapsed_since_start / progress if progress > 0 else None
+    eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None
+
+    id_live_preview = req.id_live_preview
+    shared.state.set_current_image()
+    if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview:
+        image = shared.state.current_image
+        if image is not None:
+            buffered = io.BytesIO()
+            image.save(buffered, format="png")
+            live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii")
+            id_live_preview = shared.state.id_live_preview
+        else:
+            live_preview = None
+    else:
+        live_preview = None
+
+    return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo)
+
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 7616fded..76e0e0d5 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -140,7 +140,7 @@ def store_latent(decoded):
     if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0:
         if not shared.parallel_processing_allowed:
-            shared.state.current_image = sample_to_image(decoded)
+            shared.state.assign_current_image(sample_to_image(decoded))
 
 
 class InterruptedException(BaseException):
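
The /internal/progress route defined above is what the new requestProgress() JavaScript loop shown earlier polls every 500 ms. A rough Python equivalent of that client, assuming a default local install on port 7860:

```python
# Editor's sketch: polling the progress endpoint the way the JS client does.
import base64
import json
import time
import urllib.request

def poll_progress(id_task, host="http://127.0.0.1:7860"):
    id_live_preview = -1
    while True:
        payload = json.dumps({"id_task": id_task, "id_live_preview": id_live_preview}).encode()
        req = urllib.request.Request(host + "/internal/progress", data=payload,
                                     headers={"Content-Type": "application/json"})
        res = json.load(urllib.request.urlopen(req))

        if res["completed"]:
            return

        if res.get("live_preview"):
            # live_preview is a data:image/png;base64,... uri; keeping the id
            # stops the server from resending the same image next time
            png_bytes = base64.b64decode(res["live_preview"].split(",", 1)[1])
            id_live_preview = res["id_live_preview"]

        time.sleep(0.5)
```
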
diff --git a/modules/shared.py b/modules/shared.py
index 51df056c..de99aca9 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -152,6 +152,7 @@ def reload_hypernetworks():
     hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
 
+
 class State:
     skipped = False
     interrupted = False
@@ -165,6 +166,7 @@ class State:
     current_latent = None
     current_image = None
     current_image_sampling_step = 0
+    id_live_preview = 0
     textinfo = None
     time_start = None
     need_restart = False
@@ -207,6 +209,7 @@ class State:
         self.current_latent = None
         self.current_image = None
         self.current_image_sampling_step = 0
+        self.id_live_preview = 0
         self.skipped = False
         self.interrupted = False
         self.textinfo = None
@@ -220,8 +223,8 @@ class State:
 
         devices.torch_gc()
 
-    """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
     def set_current_image(self):
+        """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
         if not parallel_processing_allowed:
             return
 
@@ -234,12 +237,16 @@ class State:
         import modules.sd_samplers
         if opts.show_progress_grid:
-            self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent)
+            self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
         else:
-            self.current_image = modules.sd_samplers.sample_to_image(self.current_latent)
+            self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
 
         self.current_image_sampling_step = self.sampling_step
 
+    def assign_current_image(self, image):
+        self.current_image = image
+        self.id_live_preview += 1
+
 
 state = State()
 state.server_start = time.time()
@@ -424,8 +431,6 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
 }))
 
 options_templates.update(options_section(('ui', "User interface"), {
-    "show_progressbar": OptionInfo(True, "Show progressbar"),
-    "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
     "return_grid": OptionInfo(True, "Show grid in results for web"),
     "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
     "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
@@ -446,6 +451,7 @@ options_templates.update(options_section(('ui', "User interface"), {
 
 options_templates.update(options_section(('ui', "Live previews"), {
     "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
+    "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
    "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. 
Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}), "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 3c1042ad..64abff4d 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -12,7 +12,7 @@ from modules.shared import opts, cmd_opts from modules.textual_inversion import autocrop -def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False): +def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False): try: if process_caption: shared.interrogator.load() diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 63935878..7e4a6d24 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -345,7 +345,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 template_file = textual_inversion_templates.get(template_filename, None) @@ -510,7 +510,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}" pbar.set_description(description) - shared.state.textinfo = description if embedding_dir is not None and steps_done % save_embedding_every == 0: # Before saving, change name to match current checkpoint. 
embedding_name_every = f'{embedding_name}-{steps_done}' @@ -560,7 +559,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ shared.sd_model.first_stage_model.to(devices.cpu) if image is not None: - shared.state.current_image = image + shared.state.assign_current_image(image) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" diff --git a/modules/txt2img.py b/modules/txt2img.py index 38b5f591..ca5d4550 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -8,7 +8,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args): +def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, diff --git a/modules/ui.py b/modules/ui.py index 2425c66f..ff33236b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -356,7 +356,7 @@ def create_toprow(is_img2img): button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") with gr.Column(scale=1): - with gr.Row(): + with gr.Row(elem_id=f"{id_part}_generate_box"): skip = gr.Button('Skip', elem_id=f"{id_part}_skip") interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') @@ -384,9 +384,7 @@ def create_toprow(is_img2img): def setup_progressbar(*args, **kwargs): - import modules.ui_progress - - modules.ui_progress.setup_progressbar(*args, **kwargs) + pass def apply_setting(key, value): @@ -479,8 +477,8 @@ Requested path was: {f} else: sp.Popen(["xdg-open", path]) - with gr.Column(variant='panel'): - with gr.Group(): + with gr.Column(variant='panel', elem_id=f"{tabname}_results"): + with gr.Group(elem_id=f"{tabname}_gallery_container"): result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) generation_info = None @@ -595,15 +593,6 @@ def create_ui(): dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - with gr.Row(elem_id='txt2img_progress_row'): - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="txt2img_progressbar") - txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) - setup_progressbar(progressbar, 
txt2img_preview, 'txt2img') - with gr.Row().style(equal_height=False): with gr.Column(variant='panel', elem_id="txt2img_settings"): for category in ordered_ui_categories(): @@ -682,6 +671,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), _js="submit", inputs=[ + dummy_component, txt2img_prompt, txt2img_negative_prompt, txt2img_prompt_style, @@ -782,16 +772,7 @@ def create_ui(): with gr.Blocks(analytics_enabled=False) as img2img_interface: img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) - with gr.Row(elem_id='img2img_progress_row'): - img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="img2img_progressbar") - img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) - setup_progressbar(progressbar, img2img_preview, 'img2img') + img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) with FormRow().style(equal_height=False): with gr.Column(variant='panel', elem_id="img2img_settings"): @@ -958,6 +939,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), _js="submit_img2img", inputs=[ + dummy_component, dummy_component, img2img_prompt, img2img_negative_prompt, @@ -1335,15 +1317,11 @@ def create_ui(): script_callbacks.ui_train_tabs_callback(params) - with gr.Column(): - progressbar = gr.HTML(elem_id="ti_progressbar") + with gr.Column(elem_id='ti_gallery_container'): ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) - ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) - ti_preview = gr.Image(elem_id='ti_preview', visible=False) ti_progress = gr.HTML(elem_id="ti_progress", value="") ti_outcome = gr.HTML(elem_id="ti_error", value="") - setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) create_embedding.click( fn=modules.textual_inversion.ui.create_embedding, @@ -1384,6 +1362,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ + dummy_component, process_src, process_dst, process_width, @@ -1411,6 +1390,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ + dummy_component, train_embedding_name, embedding_learn_rate, batch_size, @@ -1443,6 +1423,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ + dummy_component, train_hypernetwork_name, hypernetwork_learn_rate, batch_size, diff --git a/modules/ui_progress.py b/modules/ui_progress.py deleted file mode 100644 index 7cd312e4..00000000 --- a/modules/ui_progress.py +++ /dev/null @@ -1,101 +0,0 @@ -import time - -import gradio as gr - -from modules.shared import opts - -import modules.shared as shared - - -def calc_time_left(progress, threshold, label, force_display, show_eta): - if progress == 0: - return "" - else: - time_since_start = time.time() - shared.state.time_start - eta = 
(time_since_start/progress) - eta_relative = eta-time_since_start - if (eta_relative > threshold and show_eta) or force_display: - if eta_relative > 3600: - return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) - elif eta_relative > 60: - return label + time.strftime('%M:%S', time.gmtime(eta_relative)) - else: - return label + time.strftime('%Ss', time.gmtime(eta_relative)) - else: - return "" - - -def check_progress_call(id_part): - if shared.state.job_count == 0: - return "", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) - - progress = 0 - - if shared.state.job_count > 0: - progress += shared.state.job_no / shared.state.job_count - if shared.state.sampling_steps > 0: - progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps - - # Show progress percentage and time left at the same moment, and base it also on steps done - show_eta = progress >= 0.01 or shared.state.sampling_step >= 10 - - time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta) - if time_left != "": - shared.state.time_left_force_display = True - - progress = min(progress, 1) - - progressbar = "" - if opts.show_progressbar: - progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}
""" - - image = gr.update(visible=False) - preview_visibility = gr.update(visible=False) - - if opts.live_previews_enable: - shared.state.set_current_image() - image = shared.state.current_image - - if image is None: - image = gr.update(value=None) - else: - preview_visibility = gr.update(visible=True) - - if shared.state.textinfo is not None: - textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) - else: - textinfo_result = gr.update(visible=False) - - return f"

{progressbar}

", preview_visibility, image, textinfo_result - - -def check_progress_call_initial(id_part): - shared.state.job_count = -1 - shared.state.current_latent = None - shared.state.current_image = None - shared.state.textinfo = None - shared.state.time_start = time.time() - shared.state.time_left_force_display = False - - return check_progress_call(id_part) - - -def setup_progressbar(progressbar, preview, id_part, textinfo=None): - if textinfo is None: - textinfo = gr.HTML(visible=False) - - check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) - check_progress.click( - fn=lambda: check_progress_call(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) - check_progress_initial.click( - fn=lambda: check_progress_call_initial(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) diff --git a/style.css b/style.css index 2d484e06..786b71d1 100644 --- a/style.css +++ b/style.css @@ -305,26 +305,42 @@ input[type="range"]{ } .progressDiv{ - width: 100%; - height: 20px; - background: #b4c0cc; - border-radius: 8px; + position: absolute; + height: 20px; + top: -20px; + background: #b4c0cc; + border-radius: 8px !important; } .dark .progressDiv{ - background: #424c5b; + background: #424c5b; } .progressDiv .progress{ - width: 0%; - height: 20px; - background: #0060df; - color: white; - font-weight: bold; - line-height: 20px; - padding: 0 8px 0 0; - text-align: right; - border-radius: 8px; + width: 0%; + height: 20px; + background: #0060df; + color: white; + font-weight: bold; + line-height: 20px; + padding: 0 8px 0 0; + text-align: right; + border-radius: 8px; + overflow: visible; + white-space: nowrap; +} + +.livePreview{ + position: absolute; + z-index: 300; + background-color: white; + margin: -4px; +} + +.livePreview img{ + object-fit: contain; + width: 100%; + height: 100%; } #lightboxModal{ @@ -450,23 +466,25 @@ input[type="range"]{ display:none } -#txt2img_interrupt, #img2img_interrupt{ - position: absolute; - width: 50%; - height: 72px; - background: #b4c0cc; - border-radius: 0px; - display: none; +#txt2img_generate_box, #img2img_generate_box{ + position: relative; +} + +#txt2img_interrupt, #img2img_interrupt, #txt2img_skip, #img2img_skip{ + position: absolute; + width: 50%; + height: 100%; + background: #b4c0cc; + display: none; } +#txt2img_interrupt, #img2img_interrupt{ + right: 0; + border-radius: 0 0.5rem 0.5rem 0; +} #txt2img_skip, #img2img_skip{ - position: absolute; - width: 50%; - right: 0px; - height: 72px; - background: #b4c0cc; - border-radius: 0px; - display: none; + left: 0; + border-radius: 0.5rem 0 0 0.5rem; } .red { diff --git a/webui.py b/webui.py index 1fff80da..4624fe18 100644 --- a/webui.py +++ b/webui.py @@ -34,6 +34,7 @@ import modules.sd_vae import modules.txt2img import modules.script_callbacks import modules.textual_inversion.textual_inversion +import modules.progress import modules.ui from modules import modelloader @@ -181,6 +182,8 @@ def webui(): app.add_middleware(GZipMiddleware, minimum_size=1000) + modules.progress.setup_progress_api(app) + if launch_api: create_api(app) -- cgit v1.2.3 From 40ff6db5325fc34ad4fa35e80cb1e7768d9f7e75 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 08:36:07 +0300 Subject: extra networks UI rework of hypernets: rather than via settings, 
hypernets are added directly to prompt as <hypernet:name:multiplier>
---
 html/card-no-preview.png                       | Bin 0 -> 84440 bytes
 html/extra-networks-card.html                  |  11 ++
 html/extra-networks-no-cards.html              |   8 ++
 javascript/extraNetworks.js                    |  60 ++++++++
 javascript/hints.js                            |   2 +
 javascript/ui.js                               |   9 +-
 modules/api/api.py                             |   7 +-
 modules/extra_networks.py                      | 147 +++++++++++++++++++
 modules/extra_networks_hypernet.py             |  21 +++
 modules/generation_parameters_copypaste.py     |  12 +-
 modules/hypernetworks/hypernetwork.py          | 107 +++++++++-----
 modules/hypernetworks/ui.py                    |   5 +-
 modules/processing.py                          |  24 ++--
 modules/sd_hijack_optimizations.py             |  10 +-
 modules/shared.py                              |  21 ++-
 modules/textual_inversion/textual_inversion.py |   2 +
 modules/ui.py                                  |  50 ++++---
 modules/ui_components.py                       |  10 ++
 modules/ui_extra_networks.py                   | 149 +++++++++++++++++++
 modules/ui_extra_networks_hypernets.py         |  34 +++++
 modules/ui_extra_networks_textual_inversion.py |  32 +++++
 script.js                                      |  13 +-
 scripts/xy_grid.py                             |  29 ----
 style.css                                      | 190 +++++++++++++------
 webui.py                                       |  26 +++-
 25 files changed, 765 insertions(+), 214 deletions(-)
 create mode 100644 html/card-no-preview.png
 create mode 100644 html/extra-networks-card.html
 create mode 100644 html/extra-networks-no-cards.html
 create mode 100644 javascript/extraNetworks.js
 create mode 100644 modules/extra_networks.py
 create mode 100644 modules/extra_networks_hypernet.py
 create mode 100644 modules/ui_extra_networks.py
 create mode 100644 modules/ui_extra_networks_hypernets.py
 create mode 100644 modules/ui_extra_networks_textual_inversion.py

(limited to 'webui.py')

diff --git a/html/card-no-preview.png b/html/card-no-preview.png
new file mode 100644
index 00000000..e2beb269
Binary files /dev/null and b/html/card-no-preview.png differ
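
This commit replaces the global hypernetwork setting with per-prompt tags of the form <hypernet:name:multiplier>. The tags are cut out of the prompt before conditioning and collected as arguments, using the same pattern as re_extra_net in modules/extra_networks.py below; a small sketch of that behaviour:

```python
# Editor's sketch: how extra-network tags are extracted from a prompt.
import re

re_extra_net = re.compile(r"<(\w+):([^>]+)>")

prompt = "masterpiece, 1girl <hypernet:agm:1.1> <hypernet:ray>"

found = [(m.group(1), m.group(2).split(":")) for m in re_extra_net.finditer(prompt)]
clean = re_extra_net.sub("", prompt)

print(found)  # [('hypernet', ['agm', '1.1']), ('hypernet', ['ray'])]
print(clean)  # 'masterpiece, 1girl  ' - the tags never reach the model
```
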
diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html
new file mode 100644
index 00000000..7314b063
--- /dev/null
+++ b/html/extra-networks-card.html
@@ -0,0 +1,11 @@
+[11 added lines of HTML were stripped during extraction; of the card template markup only the {name} label placeholder survived]
diff --git a/html/extra-networks-no-cards.html b/html/extra-networks-no-cards.html
new file mode 100644
index 00000000..389358d6
--- /dev/null
+++ b/html/extra-networks-no-cards.html
@@ -0,0 +1,8 @@
+[8 added lines of HTML were stripped during extraction; the surviving text is the heading "Nothing here. Add some content to the following directories:" and a {dirs} placeholder]
+ diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js new file mode 100644 index 00000000..71e522d1 --- /dev/null +++ b/javascript/extraNetworks.js @@ -0,0 +1,60 @@ + +function setupExtraNetworksForTab(tabname){ + gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks') + + gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_refresh')) + gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_close')) +} + +var activePromptTextarea = null; +var activePositivePromptTextarea = null; + +function setupExtraNetworks(){ + setupExtraNetworksForTab('txt2img') + setupExtraNetworksForTab('img2img') + + function registerPrompt(id, isNegative){ + var textarea = gradioApp().querySelector("#" + id + " > label > textarea"); + + if (activePromptTextarea == null){ + activePromptTextarea = textarea + } + if (activePositivePromptTextarea == null && ! isNegative){ + activePositivePromptTextarea = textarea + } + + textarea.addEventListener("focus", function(){ + activePromptTextarea = textarea; + if(! isNegative) activePositivePromptTextarea = textarea; + }); + } + + registerPrompt('txt2img_prompt') + registerPrompt('txt2img_neg_prompt', true) + registerPrompt('img2img_prompt') + registerPrompt('img2img_neg_prompt', true) +} + +onUiLoaded(setupExtraNetworks) + +function cardClicked(textToAdd, allowNegativePrompt){ + textarea = allowNegativePrompt ? activePromptTextarea : activePositivePromptTextarea + + textarea.value = textarea.value + " " + textToAdd + updateInput(textarea) + + return false +} + +function saveCardPreview(event, tabname, filename){ + textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea') + button = gradioApp().getElementById(tabname + '_save_preview') + + textarea.value = filename + updateInput(textarea) + + button.click() + + event.stopPropagation() + event.preventDefault() +} diff --git a/javascript/hints.js b/javascript/hints.js index e746e20d..f4079f96 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -21,6 +21,8 @@ titles = { "\U0001F5D1": "Clear prompt", "\u{1f4cb}": "Apply selected styles to current prompt", "\u{1f4d2}": "Paste available values into the field", + "\u{1f3b4}": "Show extra networks", + "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", diff --git a/javascript/ui.js b/javascript/ui.js index 3ba90ca8..a7e75439 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -196,8 +196,6 @@ function confirm_clear_prompt(prompt, negative_prompt) { return [prompt, negative_prompt] } - - opts = {} onUiUpdate(function(){ if(Object.keys(opts).length != 0) return; @@ -239,11 +237,14 @@ onUiUpdate(function(){ return } + prompt.parentElement.insertBefore(counter, prompt) counter.classList.add("token-counter") prompt.parentElement.style.position = "relative" - textarea.addEventListener("input", () => update_token_counter(id_button)); + textarea.addEventListener("input", function(){ + update_token_counter(id_button); + }); } registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button') @@ -261,10 +262,8 @@ onUiUpdate(function(){ }) } } - }) - onOptionsChanged(function(){ elem = gradioApp().getElementById('sd_checkpoint_hash') sd_checkpoint_hash = 
opts.sd_checkpoint_hash || ""

diff --git a/modules/api/api.py b/modules/api/api.py
index 9814bbc2..2c371e6e 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -480,7 +480,7 @@ class Api:
     def train_hypernetwork(self, args: dict):
         try:
             shared.state.begin()
-            initial_hypernetwork = shared.loaded_hypernetwork
+            shared.loaded_hypernetworks = []
             apply_optimizations = shared.opts.training_xattention_optimizations
             error = None
             filename = ''
@@ -491,16 +491,15 @@ class Api:
         except Exception as e:
             error = e
         finally:
-            shared.loaded_hypernetwork = initial_hypernetwork
             shared.sd_model.cond_stage_model.to(devices.device)
             shared.sd_model.first_stage_model.to(devices.device)
             if not apply_optimizations:
                 sd_hijack.apply_optimizations()
             shared.state.end()
-            return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
+            return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
         except AssertionError as msg:
             shared.state.end()
-            return TrainResponse(info = "train embedding error: {error}".format(error = error))
+            return TrainResponse(info="train embedding error: {error}".format(error=error))
 
     def get_memory(self):
         try:
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
new file mode 100644
index 00000000..1978673d
--- /dev/null
+++ b/modules/extra_networks.py
@@ -0,0 +1,147 @@
+import re
+from collections import defaultdict
+
+from modules import errors
+
+extra_network_registry = {}
+
+
+def initialize():
+    extra_network_registry.clear()
+
+
+def register_extra_network(extra_network):
+    extra_network_registry[extra_network.name] = extra_network
+
+
+class ExtraNetworkParams:
+    def __init__(self, items=None):
+        self.items = items or []
+
+
+class ExtraNetwork:
+    def __init__(self, name):
+        self.name = name
+
+    def activate(self, p, params_list):
+        """
+        Called by processing on every run. Whatever the extra network is meant to do should be activated here.
+        Passes arguments related to this extra network in params_list.
+        User passes arguments by specifying this in his prompt:
+
+        <name:arg1:arg2:arg3>
+
+        Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments
+        separated by colon.
+
+        Even if the user does not mention this ExtraNetwork in his prompt, the call will still be made, with empty params_list -
+        in this case, all effects of this extra network should be disabled.
+
+        Can be called multiple times before deactivate() - each new call should override the previous call completely.
+
+        For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is:
+
+        > "1girl, <hypernet:agm:1.1> <hypernet:ray>"
+
+        params_list will be:
+
+        [
+            ExtraNetworkParams(items=["agm", "1.1"]),
+            ExtraNetworkParams(items=["ray"])
+        ]
+
+        """
+        raise NotImplementedError
+
+    def deactivate(self, p):
+        """
+        Called at the end of processing for housekeeping. No need to do anything here.
+        """
+
+        raise NotImplementedError
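
Everything an extra network needs to do is expressed through the two methods above. A minimal, hypothetical plugin wired into the registry, following that contract (the 'dummy' network and its print call are illustrative only, not part of the patch):

```python
# Editor's sketch: the smallest possible ExtraNetwork implementation.
from modules import extra_networks

class ExtraNetworkDummy(extra_networks.ExtraNetwork):
    def __init__(self):
        super().__init__('dummy')

    def activate(self, p, params_list):
        # called on every run; an empty params_list means "disable yourself"
        for params in params_list:
            print(f"dummy activated with args: {params.items}")

    def deactivate(self, p):
        pass  # housekeeping only

extra_networks.register_extra_network(ExtraNetworkDummy())
# a prompt containing "<dummy:0.5>" now calls activate() with
# params_list == [ExtraNetworkParams(items=["0.5"])]
```
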
+
+
+def activate(p, extra_network_data):
+    """call activate for extra networks in extra_network_data in specified order, then call
+    activate for all remaining registered networks with an empty argument list"""
+
+    for extra_network_name, extra_network_args in extra_network_data.items():
+        extra_network = extra_network_registry.get(extra_network_name, None)
+        if extra_network is None:
+            print(f"Skipping unknown extra network: {extra_network_name}")
+            continue
+
+        try:
+            extra_network.activate(p, extra_network_args)
+        except Exception as e:
+            errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")
+
+    for extra_network_name, extra_network in extra_network_registry.items():
+        args = extra_network_data.get(extra_network_name, None)
+        if args is not None:
+            continue
+
+        try:
+            extra_network.activate(p, [])
+        except Exception as e:
+            errors.display(e, f"activating extra network {extra_network_name}")
+
+
+def deactivate(p, extra_network_data):
+    """call deactivate for extra networks in extra_network_data in specified order, then call
+    deactivate for all remaining registered networks"""
+
+    for extra_network_name, extra_network_args in extra_network_data.items():
+        extra_network = extra_network_registry.get(extra_network_name, None)
+        if extra_network is None:
+            continue
+
+        try:
+            extra_network.deactivate(p)
+        except Exception as e:
+            errors.display(e, f"deactivating extra network {extra_network_name}")
+
+    for extra_network_name, extra_network in extra_network_registry.items():
+        args = extra_network_data.get(extra_network_name, None)
+        if args is not None:
+            continue
+
+        try:
+            extra_network.deactivate(p)
+        except Exception as e:
+            errors.display(e, f"deactivating unmentioned extra network {extra_network_name}")
+
+
+re_extra_net = re.compile(r"<(\w+):([^>]+)>")
+
+
+def parse_prompt(prompt):
+    res = defaultdict(list)
+
+    def found(m):
+        name = m.group(1)
+        args = m.group(2)
+
+        res[name].append(ExtraNetworkParams(items=args.split(":")))
+
+        return ""
+
+    prompt = re.sub(re_extra_net, found, prompt)
+
+    return prompt, res
+
+
+def parse_prompts(prompts):
+    res = []
+    extra_data = None
+
+    for prompt in prompts:
+        updated_prompt, parsed_extra_data = parse_prompt(prompt)
+
+        if extra_data is None:
+            extra_data = parsed_extra_data
+
+        res.append(updated_prompt)
+
+    return res, extra_data
+
diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py
new file mode 100644
index 00000000..6a0c4ba8
--- /dev/null
+++ b/modules/extra_networks_hypernet.py
@@ -0,0 +1,21 @@
+from modules import extra_networks
+from modules.hypernetworks import hypernetwork
+
+
+class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
+    def __init__(self):
+        super().__init__('hypernet')
+
+    def activate(self, p, params_list):
+        names = []
+        multipliers = []
+        for params in params_list:
+            assert len(params.items) > 0
+
+            names.append(params.items[0])
+            multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+
+        hypernetwork.load_hypernetworks(names, multipliers)
+
+    def deactivate(self, p):
+        pass
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index a381ff59..46e12dc6 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -79,8 +79,6 @@ def integrate_settings_paste_fields(component_dict):
     from modules import ui
     settings_map = {
-        'sd_hypernetwork': 'Hypernet',
-        'sd_hypernetwork_strength': 
'Hypernet strength',
         'CLIP_stop_at_last_layers': 'Clip skip',
         'inpainting_mask_weight': 'Conditional mask weight',
         'sd_model_checkpoint': 'Model hash',
@@ -275,13 +273,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
     if "Clip skip" not in res:
         res["Clip skip"] = "1"
 
-    if "Hypernet strength" not in res:
-        res["Hypernet strength"] = "1"
-
-    if "Hypernet" in res:
-        hypernet_name = res["Hypernet"]
-        hypernet_hash = res.get("Hypernet hash", None)
-        res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash)
+    hypernet = res.get("Hypernet", None)
+    if hypernet is not None:
+        res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""
 
     if "Hires resize-1" not in res:
         res["Hires resize-1"] = 0
 
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 74e78582..80a47c79 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -25,7 +25,6 @@ from statistics import stdev, mean
 optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
 
 class HypernetworkModule(torch.nn.Module):
-    multiplier = 1.0
     activation_dict = {
         "linear": torch.nn.Identity,
         "relu": torch.nn.ReLU,
@@ -41,6 +40,8 @@ class HypernetworkModule(torch.nn.Module):
                  add_layer_norm=False, activate_output=False, dropout_structure=None):
         super().__init__()
 
+        self.multiplier = 1.0
+
         assert layer_structure is not None, "layer_structure must not be None"
         assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
         assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
@@ -115,7 +116,7 @@ class HypernetworkModule(torch.nn.Module):
             state_dict[to] = x
 
     def forward(self, x):
-        return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1)
+        return x + self.linear(x) * (self.multiplier if not self.training else 1)
 
     def trainables(self):
         layer_structure = []
@@ -125,9 +126,6 @@ class HypernetworkModule(torch.nn.Module):
         return layer_structure
 
 
-def apply_strength(value=None):
-    HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
-
 #param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check. 
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout): if layer_structure is None: @@ -192,6 +190,20 @@ class Hypernetwork: for param in layer.parameters(): param.requires_grad = mode + def to(self, device): + for k, layers in self.layers.items(): + for layer in layers: + layer.to(device) + + return self + + def set_multiplier(self, multiplier): + for k, layers in self.layers.items(): + for layer in layers: + layer.multiplier = multiplier + + return self + def eval(self): for k, layers in self.layers.items(): for layer in layers: @@ -269,11 +281,13 @@ class Hypernetwork: self.optimizer_state_dict = None if self.optimizer_state_dict: self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW') - print("Loaded existing optimizer from checkpoint") - print(f"Optimizer name is {self.optimizer_name}") + if shared.opts.print_hypernet_extra: + print("Loaded existing optimizer from checkpoint") + print(f"Optimizer name is {self.optimizer_name}") else: self.optimizer_name = "AdamW" - print("No saved optimizer exists in checkpoint") + if shared.opts.print_hypernet_extra: + print("No saved optimizer exists in checkpoint") for size, sd in state_dict.items(): if type(size) == int: @@ -306,23 +320,43 @@ def list_hypernetworks(path): return res -def load_hypernetwork(filename): - path = shared.hypernetworks.get(filename, None) - # Prevent any file named "None.pt" from being loaded. - if path is not None and filename != "None": - print(f"Loading hypernetwork {filename}") - try: - shared.loaded_hypernetwork = Hypernetwork() - shared.loaded_hypernetwork.load(path) +def load_hypernetwork(name): + path = shared.hypernetworks.get(name, None) - except Exception: - print(f"Error loading hypernetwork {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - if shared.loaded_hypernetwork is not None: - print("Unloading hypernetwork") + if path is None: + return None + + hypernetwork = Hypernetwork() + + try: + hypernetwork.load(path) + except Exception: + print(f"Error loading hypernetwork {path}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + return None + + return hypernetwork + + +def load_hypernetworks(names, multipliers=None): + already_loaded = {} + + for hypernetwork in shared.loaded_hypernetworks: + if hypernetwork.name in names: + already_loaded[hypernetwork.name] = hypernetwork - shared.loaded_hypernetwork = None + shared.loaded_hypernetworks.clear() + + for i, name in enumerate(names): + hypernetwork = already_loaded.get(name, None) + if hypernetwork is None: + hypernetwork = load_hypernetwork(name) + + if hypernetwork is None: + continue + + hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0) + shared.loaded_hypernetworks.append(hypernetwork) def find_closest_hypernetwork_name(search: str): @@ -336,18 +370,27 @@ def find_closest_hypernetwork_name(search: str): return applicable[0] -def apply_hypernetwork(hypernetwork, context, layer=None): - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) +def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None): + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None) if hypernetwork_layers is None: - return context, context + return context_k, context_v if layer is not None: layer.hyper_k = hypernetwork_layers[0] layer.hyper_v = hypernetwork_layers[1] - context_k = hypernetwork_layers[0](context) - context_v = 
hypernetwork_layers[1](context) + context_k = hypernetwork_layers[0](context_k) + context_v = hypernetwork_layers[1](context_v) + return context_k, context_v + + +def apply_hypernetworks(hypernetworks, context, layer=None): + context_k = context + context_v = context + for hypernetwork in hypernetworks: + context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer) + return context_k, context_v @@ -357,7 +400,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self) + context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self) k = self.to_k(context_k) v = self.to_v(context_v) @@ -464,8 +507,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi template_file = template_file.path path = shared.hypernetworks.get(hypernetwork_name, None) - shared.loaded_hypernetwork = Hypernetwork() - shared.loaded_hypernetwork.load(path) + hypernetwork = Hypernetwork() + hypernetwork.load(path) + shared.loaded_hypernetworks = [hypernetwork] shared.state.job = "train-hypernetwork" shared.state.textinfo = "Initializing hypernetwork training..." @@ -489,7 +533,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi else: images_dir = None - hypernetwork = shared.loaded_hypernetwork checkpoint = sd_models.select_checkpoint() initial_step = hypernetwork.step or 0 diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 81e3f519..76599f5a 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -9,6 +9,7 @@ from modules import devices, sd_hijack, shared not_available = ["hardswish", "multiheadattention"] keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) + def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure) @@ -16,8 +17,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, def train_hypernetwork(*args): - - initial_hypernetwork = shared.loaded_hypernetwork + shared.loaded_hypernetworks = [] assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible' @@ -34,7 +34,6 @@ Hypernetwork saved to {html.escape(filename)} except Exception: raise finally: - shared.loaded_hypernetwork = initial_hypernetwork shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device) sd_hijack.apply_optimizations() diff --git a/modules/processing.py b/modules/processing.py index a3e9f709..b5deeacf 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import 
modules.shared as shared @@ -438,9 +438,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), - "Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()), - "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), @@ -468,14 +465,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try: for k, v in p.override_settings.items(): setattr(opts, k, v) - if k == 'sd_hypernetwork': - shared.reload_hypernetworks() # make onchange call for changing hypernet if k == 'sd_model_checkpoint': - sd_models.reload_model_weights() # make onchange call for changing SD model + sd_models.reload_model_weights() if k == 'sd_vae': - sd_vae.reload_vae_weights() # make onchange call for changing VAE + sd_vae.reload_vae_weights() res = process_images_inner(p) @@ -484,9 +479,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if p.override_settings_restore_afterwards: for k, v in stored_opts.items(): setattr(opts, k, v) - if k == 'sd_hypernetwork': shared.reload_hypernetworks() - if k == 'sd_model_checkpoint': sd_models.reload_model_weights() - if k == 'sd_vae': sd_vae.reload_vae_weights() + if k == 'sd_model_checkpoint': + sd_models.reload_model_weights() + + if k == 'sd_vae': + sd_vae.reload_vae_weights() return res @@ -564,10 +561,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: cache[0] = (required_prompts, steps) return cache[1] + p.all_prompts, extra_network_data = extra_networks.parse_prompts(p.all_prompts) + with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) + extra_networks.activate(p, extra_network_data) + with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: processed = Processed(p, [], p.seed, "") file.write(processed.infotext(p, 0)) @@ -681,6 +682,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) + extra_networks.deactivate(p, extra_network_data) devices.torch_gc() res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index cdc63ed7..4fa54329 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -44,7 +44,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q_in = self.to_q(x) context = 
default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k_in = self.to_k(context_k) v_in = self.to_v(context_v) del context, context_k, context_v, x @@ -78,7 +78,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k_in = self.to_k(context_k) v_in = self.to_v(context_v) @@ -203,7 +203,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k = self.to_k(context_k) * self.scale v = self.to_v(context_v) del context, context_k, context_v, x @@ -225,7 +225,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k = self.to_k(context_k) v = self.to_v(context_v) del context, context_k, context_v, x @@ -284,7 +284,7 @@ def xformers_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k_in = self.to_k(context_k) v_in = self.to_v(context_v) diff --git a/modules/shared.py b/modules/shared.py index 2f366454..c0e11f18 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -23,6 +23,7 @@ demo = None sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml") sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file + parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) @@ -145,7 +146,7 @@ config_filename = cmd_opts.ui_settings_file os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = {} -loaded_hypernetwork = None +loaded_hypernetworks = [] def reload_hypernetworks(): @@ -153,8 +154,6 @@ def reload_hypernetworks(): global hypernetworks hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) - hypernetwork.load_hypernetwork(opts.sd_hypernetwork) - class State: @@ -399,8 +398,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), - 
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), - "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), @@ -661,3 +658,17 @@ mem_mon.start() def listfiles(dirname): filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")] return [file for file in filenames if os.path.isfile(file)] + + +def html_path(filename): + return os.path.join(script_path, "html", filename) + + +def html(filename): + path = html_path(filename) + + if os.path.exists(path): + with open(path, encoding="utf8") as file: + return file.read() + + return "" diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5a7be422..4e90f690 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -50,6 +50,7 @@ class Embedding: self.sd_checkpoint = None self.sd_checkpoint_name = None self.optimizer_state_dict = None + self.filename = None def save(self, filename): embedding_data = { @@ -182,6 +183,7 @@ class EmbeddingDatabase: embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) embedding.vectors = vec.shape[0] embedding.shape = vec.shape[-1] + embedding.filename = path if self.expected_shape == -1 or self.expected_shape == embedding.shape: self.register_embedding(embedding, shared.sd_model) diff --git a/modules/ui.py b/modules/ui.py index 06c11848..d23b2b8e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -20,7 +20,7 @@ import numpy as np from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path @@ -90,6 +90,7 @@ refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 clear_prompt_symbol = '\U0001F5D1' # 🗑️ +extra_networks_symbol = '\U0001F3B4' # 🎴 def plaintext_to_html(text): @@ -324,6 +325,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: def update_token_counter(text, steps): try: + text, _ = extra_networks.parse_prompt(text) + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) @@ -354,10 +357,10 @@ def create_toprow(is_img2img): negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)") with gr.Column(scale=1, elem_id="roll_col"): - paste = gr.Button(value=paste_symbol, elem_id="paste") - save_style = gr.Button(value=save_style_symbol, 
elem_id="style_create") - prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + paste = ToolButton(value=paste_symbol, elem_id="paste") + clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks") + token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") negative_token_counter = gr.HTML(value="", elem_id=f"{id_part}_negative_token_counter") @@ -395,11 +398,14 @@ def create_toprow(is_img2img): outputs=[], ) - with gr.Row(): + with gr.Row(elem_id=f"{id_part}_styles_row"): prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True) create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles") - return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, negative_token_counter, negative_token_button + prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id="style_apply") + save_style = ToolButton(value=save_style_symbol, elem_id="style_create") + + return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button def setup_progressbar(*args, **kwargs): @@ -616,11 +622,15 @@ def create_ui(): modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False) + txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) + with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks: + from modules import ui_extra_networks + extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img') + with gr.Row().style(equal_height=False): with gr.Column(variant='compact', elem_id="txt2img_settings"): for category in ordered_ui_categories(): @@ -794,14 +804,20 @@ def create_ui(): token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) + ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) + modules.scripts.scripts_current = modules.scripts.scripts_img2img modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) with 
gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True) + img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True) img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) + with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks: + from modules import ui_extra_networks + extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img') + with FormRow().style(equal_height=False): with gr.Column(variant='compact', elem_id="img2img_settings"): copy_image_buttons = [] @@ -1064,6 +1080,8 @@ def create_ui(): token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) + ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) + img2img_paste_fields = [ (img2img_prompt, "Prompt"), (img2img_negative_prompt, "Negative prompt"), @@ -1666,10 +1684,8 @@ def create_ui(): download_localization = gr.Button(value='Download localization template', elem_id="download_localization") reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") - if os.path.exists("html/licenses.html"): - with open("html/licenses.html", encoding="utf8") as file: - with gr.TabItem("Licenses"): - gr.HTML(file.read(), elem_id="licenses") + with gr.TabItem("Licenses"): + gr.HTML(shared.html("licenses.html"), elem_id="licenses") gr.Button(value="Show all pages", elem_id="settings_show_all_pages") @@ -1756,11 +1772,9 @@ def create_ui(): if os.path.exists(os.path.join(script_path, "notification.mp3")): audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) - if os.path.exists("html/footer.html"): - with open("html/footer.html", encoding="utf8") as file: - footer = file.read() - footer = footer.format(versions=versions_html()) - gr.HTML(footer, elem_id="footer") + footer = shared.html("footer.html") + footer = footer.format(versions=versions_html()) + gr.HTML(footer, elem_id="footer") text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) settings_submit.click( diff --git a/modules/ui_components.py b/modules/ui_components.py index 97acff06..46324425 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -11,6 +11,16 @@ class ToolButton(gr.Button, gr.components.FormComponent): return "button" +class ToolButtonTop(gr.Button, gr.components.FormComponent): + """Small button with single emoji as text, with extra margin at top, fits inside gradio forms""" + + def __init__(self, **kwargs): + super().__init__(variant="tool-top", **kwargs) + + def get_block_name(self): + return "button" + + class FormRow(gr.Row, 
gr.components.FormComponent): """Same as gr.Row but fits inside gradio forms""" diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py new file mode 100644 index 00000000..253e90f7 --- /dev/null +++ b/modules/ui_extra_networks.py @@ -0,0 +1,149 @@ +import os.path + +from modules import shared +import gradio as gr +import json + +from modules.generation_parameters_copypaste import image_from_url_text + +extra_pages = [] + + +def register_page(page): + """registers extra networks page for the UI; recommend doing it in on_app_started() callback for extensions""" + + extra_pages.append(page) + + +class ExtraNetworksPage: + def __init__(self, title): + self.title = title + self.card_page = shared.html("extra-networks-card.html") + self.allow_negative_prompt = False + + def refresh(self): + pass + + def create_html(self, tabname): + items_html = '' + + for item in self.list_items(): + items_html += self.create_html_for_item(item, tabname) + + if items_html == '': + dirs = "".join([f"
<li>{x}</li>" for x in self.allowed_directories_for_previews()]) + items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) + + res = "<div class='extra-network-cards'>" + items_html + "</div>
    " + + return res + + def list_items(self): + raise NotImplementedError() + + def allowed_directories_for_previews(self): + return [] + + def create_html_for_item(self, item, tabname): + preview = item.get("preview", None) + + args = { + "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '', + "prompt": json.dumps(item["prompt"]), + "tabname": json.dumps(tabname), + "local_preview": json.dumps(item["local_preview"]), + "name": item["name"], + "allow_negative_prompt": "true" if self.allow_negative_prompt else "false", + } + + return self.card_page.format(**args) + + +def intialize(): + extra_pages.clear() + + +class ExtraNetworksUi: + def __init__(self): + self.pages = None + self.stored_extra_pages = None + + self.button_save_preview = None + self.preview_target_filename = None + + self.tabname = None + + +def create_ui(container, button, tabname): + ui = ExtraNetworksUi() + ui.pages = [] + ui.stored_extra_pages = extra_pages.copy() + ui.tabname = tabname + + with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs: + button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") + button_close = gr.Button('Close', elem_id=tabname+"_extra_close") + + for page in ui.stored_extra_pages: + with gr.Tab(page.title): + page_elem = gr.HTML(page.create_html(ui.tabname)) + ui.pages.append(page_elem) + + ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) + ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) + + button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container]) + button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container]) + + def refresh(): + res = [] + + for pg in ui.stored_extra_pages: + pg.refresh() + res.append(pg.create_html(ui.tabname)) + + return res + + button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) + + return ui + + +def path_is_parent(parent_path, child_path): + parent_path = os.path.abspath(parent_path) + child_path = os.path.abspath(child_path) + + return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path]) + + +def setup_ui(ui, gallery): + def save_preview(index, images, filename): + if len(images) == 0: + print("There is no image in gallery to save as a preview.") + return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] + + index = int(index) + index = 0 if index < 0 else index + index = len(images) - 1 if index >= len(images) else index + + img_info = images[index if index >= 0 else 0] + image = image_from_url_text(img_info) + + is_allowed = False + for extra_page in ui.stored_extra_pages: + if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]): + is_allowed = True + break + + assert is_allowed, f'writing to {filename} is not allowed' + + image.save(filename) + + return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] + + ui.button_save_preview.click( + fn=save_preview, + _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}", + inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename], + outputs=[*ui.pages] + ) diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py new file mode 100644 index 00000000..312dbaf0 --- /dev/null +++ b/modules/ui_extra_networks_hypernets.py @@ -0,0 +1,34 @@ +import os + +from modules import shared, ui_extra_networks + + +class 
ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Hypernetworks') + + def refresh(self): + shared.reload_hypernetworks() + + def list_items(self): + for name, path in shared.hypernetworks.items(): + path, ext = os.path.splitext(path) + previews = [path + ".png", path + ".preview.png"] + + preview = None + for file in previews: + if os.path.isfile(file): + preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file)) + break + + yield { + "name": name, + "filename": path, + "preview": preview, + "prompt": f"", + "local_preview": path + ".png", + } + + def allowed_directories_for_previews(self): + return [shared.cmd_opts.hypernetwork_dir] + diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py new file mode 100644 index 00000000..e4a6e3bf --- /dev/null +++ b/modules/ui_extra_networks_textual_inversion.py @@ -0,0 +1,32 @@ +import os + +from modules import ui_extra_networks, sd_hijack + + +class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Textual Inversion') + self.allow_negative_prompt = True + + def refresh(self): + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) + + def list_items(self): + for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values(): + path, ext = os.path.splitext(embedding.filename) + preview_file = path + ".preview.png" + + preview = None + if os.path.isfile(preview_file): + preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file)) + + yield { + "name": embedding.name, + "filename": embedding.filename, + "preview": preview, + "prompt": embedding.name, + "local_preview": path + ".preview.png", + } + + def allowed_directories_for_previews(self): + return list(sd_hijack.model_hijack.embedding_db.embedding_dirs) diff --git a/script.js b/script.js index 3345e32b..97e0bfcf 100644 --- a/script.js +++ b/script.js @@ -13,6 +13,7 @@ function get_uiCurrentTabContent() { } uiUpdateCallbacks = [] +uiLoadedCallbacks = [] uiTabChangeCallbacks = [] optionsChangedCallbacks = [] let uiCurrentTab = null @@ -20,6 +21,9 @@ let uiCurrentTab = null function onUiUpdate(callback){ uiUpdateCallbacks.push(callback) } +function onUiLoaded(callback){ + uiLoadedCallbacks.push(callback) +} function onUiTabChange(callback){ uiTabChangeCallbacks.push(callback) } @@ -38,8 +42,15 @@ function executeCallbacks(queue, m) { queue.forEach(function(x){runCallback(x, m)}) } +var executedOnLoaded = false; + document.addEventListener("DOMContentLoaded", function() { var mutationObserver = new MutationObserver(function(m){ + if(!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')){ + executedOnLoaded = true; + executeCallbacks(uiLoadedCallbacks); + } + executeCallbacks(uiUpdateCallbacks, m); const newTab = get_uiCurrentTab(); if ( newTab && ( newTab !== uiCurrentTab ) ) { @@ -53,7 +64,7 @@ document.addEventListener("DOMContentLoaded", function() { /** * Add a ctrl+enter as a shortcut to start a generation */ - document.addEventListener('keydown', function(e) { +document.addEventListener('keydown', function(e) { var handled = false; if (e.key !== undefined) { if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 6629f5d5..b1badec9 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -11,7 +11,6 @@ import 
modules.scripts as scripts import gradio as gr from modules import images, paths, sd_samplers, processing, sd_models, sd_vae -from modules.hypernetworks import hypernetwork from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -94,28 +93,6 @@ def confirm_checkpoints(p, xs): raise RuntimeError(f"Unknown checkpoint: {x}") -def apply_hypernetwork(p, x, xs): - if x.lower() in ["", "none"]: - name = None - else: - name = hypernetwork.find_closest_hypernetwork_name(x) - if not name: - raise RuntimeError(f"Unknown hypernetwork: {x}") - hypernetwork.load_hypernetwork(name) - - -def apply_hypernetwork_strength(p, x, xs): - hypernetwork.apply_strength(x) - - -def confirm_hypernetworks(p, xs): - for x in xs: - if x.lower() in ["", "none"]: - continue - if not hypernetwork.find_closest_hypernetwork_name(x): - raise RuntimeError(f"Unknown hypernetwork: {x}") - - def apply_clip_skip(p, x, xs): opts.data["CLIP_stop_at_last_layers"] = x @@ -208,8 +185,6 @@ axis_options = [ AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), - AxisOption("Hypernetwork", str, apply_hypernetwork, format_value=format_value, confirm=confirm_hypernetworks, cost=0.2, choices=lambda: list(shared.hypernetworks)), - AxisOption("Hypernet str.", float, apply_hypernetwork_strength), AxisOption("Sigma Churn", float, apply_field("s_churn")), AxisOption("Sigma min", float, apply_field("s_tmin")), AxisOption("Sigma max", float, apply_field("s_tmax")), @@ -291,7 +266,6 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_ class SharedSettingsStackHelper(object): def __enter__(self): self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers - self.hypernetwork = opts.sd_hypernetwork self.vae = opts.sd_vae def __exit__(self, exc_type, exc_value, tb): @@ -299,9 +273,6 @@ class SharedSettingsStackHelper(object): modules.sd_models.reload_model_weights() modules.sd_vae.reload_vae_weights() - hypernetwork.load_hypernetwork(self.hypernetwork) - hypernetwork.apply_strength() - opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers diff --git a/style.css b/style.css index 3a515ebd..5e8bc2ca 100644 --- a/style.css +++ b/style.css @@ -132,13 +132,6 @@ } #roll_col > button { - min-width: 2em; - min-height: 2em; - max-width: 2em; - max-height: 2em; - flex-grow: 0; - padding-left: 0.25em; - padding-right: 0.25em; margin: 0.1em 0; } @@ -146,9 +139,10 @@ min-width: 0 !important; max-width: 8em !important; margin-right: 1em; + gap: 0; } #interrogate, #deepbooru{ - margin: 0em 0.25em 0.9em 0.25em; + margin: 0em 0.25em 0.5em 0.25em; min-width: 8em; max-width: 8em; } @@ -157,8 +151,17 @@ min-width: 8em !important; } +#txt2img_styles_row, #img2img_styles_row{ + gap: 0.25em; + margin-top: 0.5em; +} + +#txt2img_styles_row > button, #img2img_styles_row > button{ + margin: 0; +} + #txt2img_styles, #img2img_styles{ - margin-top: 1em; + padding: 0; } #txt2img_styles ul, #img2img_styles ul{ @@ -635,17 +638,21 @@ canvas[key="mask"] { background-color: rgb(31 41 55 / var(--tw-bg-opacity)); } -.gr-button-tool{ +.gr-button-tool, 
.gr-button-tool-top{ max-width: 2.5em; min-width: 2.5em !important; height: 2.4em; - margin: 1.6em 0.7em 0.55em 0; } -#tab_modelmerger .gr-button-tool{ +.gr-button-tool{ margin: 0.6em 0em 0.55em 0; } +.gr-button-tool-top, #settings .gr-button-tool{ + margin: 1.6em 0.7em 0.55em 0; +} + + #modelmerger_results_container{ margin-top: 1em; overflow: visible; @@ -763,81 +770,88 @@ footer { line-height: 2.4em; } -/* The following handles localization for right-to-left (RTL) languages like Arabic. -The rtl media type will only be activated by the logic in javascript/localization.js. -If you change anything above, you need to make sure it is RTL compliant by just running -your changes through converters like https://cssjanus.github.io/ or https://rtlcss.com/. -Then, you will need to add the RTL counterpart only if needed in the rtl section below.*/ -@media rtl { - /* this part was added manually */ - :host { - direction: rtl; - } - select, .file-preview, .gr-text-input, .output-html:has(.performance), #ti_progress { - direction: ltr; - } - #script_list > label > select, - #x_type > label > select, - #y_type > label > select { - direction: rtl; - } - .gr-radio, .gr-checkbox{ - margin-left: 0.25em; - } +#txt2img_extra_networks, #img2img_extra_networks{ + margin-top: -1em; +} - /* automatically generated with few manual modifications */ - .performance .time { - margin-right: unset; - margin-left: 0; - } - .justify-center.overflow-x-scroll { - justify-content: right; - } - .justify-center.overflow-x-scroll button:first-of-type { - margin-left: unset; - margin-right: auto; - } - .justify-center.overflow-x-scroll button:last-of-type { - margin-right: unset; - margin-left: auto; - } - #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{ - margin-right: unset; - margin-left: 8em; - } - #txt2img_progressbar, #img2img_progressbar, #ti_progressbar{ - right: unset; - left: 0; - } - .progressDiv .progress{ - padding: 0 0 0 8px; - text-align: left; - } - #lightboxModal{ - left: unset; - right: 0; - } - .modalPrev, .modalNext{ - border-radius: 3px 0 0 3px; - } - .modalNext { - right: unset; - left: 0; - border-radius: 0 3px 3px 0; - } - #imageARPreview{ - left:unset; - right:0px; - } - #txt2img_skip, #img2img_skip{ - right: unset; - left: 0px; - } - #context-menu{ - box-shadow:-1px 1px 2px #CE6400; - } - .gr-box > div > div > input.gr-text-input{ - right: unset; - left: 0.5em; - } +.extra-networks > div > [id *= '_extra_']{ + margin: 0.3em; } + +.extra-network-cards .nocards{ + margin: 1.25em 0.5em 0.5em 0.5em; +} + +.extra-network-cards .nocards h1{ + font-size: 1.5em; + margin-bottom: 1em; +} + +.extra-network-cards .nocards li{ + margin-left: 0.5em; +} + +.extra-network-cards .card{ + display: inline-block; + margin: 0.5em; + width: 16em; + height: 24em; + box-shadow: 0 0 5px rgba(128, 128, 128, 0.5); + border-radius: 0.2em; + position: relative; + + background-size: auto 100%; + background-position: center; + overflow: hidden; + cursor: pointer; + + background-image: url('./file=html/card-no-preview.png') +} + +.extra-network-cards .card:hover{ + box-shadow: 0 0 2px 0.3em rgba(0, 128, 255, 0.35); +} + +.extra-network-cards .card .actions .additional{ + display: none; +} + +.extra-network-cards .card .actions{ + position: absolute; + bottom: 0; + left: 0; + right: 0; + padding: 0.5em; + color: white; + background: rgba(0,0,0,0.5); + box-shadow: 0 0 0.25em 0.25em rgba(0,0,0,0.5); + text-shadow: 0 0 0.2em black; +} + +.extra-network-cards .card 
.actions:hover{ + box-shadow: 0 0 0.75em 0.75em rgba(0,0,0,0.5) !important; +} + +.extra-network-cards .card .actions .name{ + font-size: 1.7em; + font-weight: bold; + line-break: anywhere; +} + +.extra-network-cards .card .actions:hover .additional{ + display: block; +} + +.extra-network-cards .card ul{ + margin: 0.25em 0 0.75em 0.25em; + cursor: unset; +} + +.extra-network-cards .card ul a{ + cursor: pointer; +} + +.extra-network-cards .card ul a:hover{ + color: red; +} + diff --git a/webui.py b/webui.py index 865a7300..e8dd822a 100644 --- a/webui.py +++ b/webui.py @@ -9,16 +9,18 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware -from modules import import_hook, errors +from modules import import_hook, errors, extra_networks +from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path import torch + # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors if ".dev" in torch.__version__ or "+git" in torch.__version__: torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) -from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir +from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks import modules.codeformer_model as codeformer import modules.extras import modules.face_restoration @@ -84,10 +86,17 @@ def initialize(): shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) - shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks())) - shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) + shared.reload_hypernetworks() + + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: try: @@ -209,6 +218,15 @@ def webui(): modules.sd_models.list_models() + shared.reload_hypernetworks() + + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + if __name__ == "__main__": if cmd_opts.nowebui: -- cgit v1.2.3 From 855b9e3d1c5a1bd8c2d815d38a38bc7c410be5a8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 16:15:53 +0300 Subject: Lora support! 
update readme to reflect some recent changes --- README.md | 14 +- extensions-builtin/Lora/extra_networks_lora.py | 20 +++ extensions-builtin/Lora/lora.py | 198 ++++++++++++++++++++++ extensions-builtin/Lora/scripts/lora_script.py | 30 ++++ extensions-builtin/Lora/ui_extra_networks_lora.py | 35 ++++ modules/extra_networks_hypernet.py | 2 +- modules/script_callbacks.py | 15 ++ modules/ui_extra_networks.py | 2 +- webui.py | 2 + 9 files changed, 314 insertions(+), 4 deletions(-) create mode 100644 extensions-builtin/Lora/extra_networks_lora.py create mode 100644 extensions-builtin/Lora/lora.py create mode 100644 extensions-builtin/Lora/scripts/lora_script.py create mode 100644 extensions-builtin/Lora/ui_extra_networks_lora.py (limited to 'webui.py') diff --git a/README.md b/README.md index 1ac794e8..9c0cd1ef 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,7 @@ A browser interface based on Gradio library for Stable Diffusion. - Possible to change defaults/min/max/step values for UI elements via text config - Tiling support, a checkbox to create images that can be tiled like textures - Progress bar and live image generation preview + - Can use a separate neural network to produce previews with almost no VRAM or compute requirement - Negative prompt, an extra text field that allows you to list what you don't want to see in generated image - Styles, a way to save part of prompt and easily apply them via dropdown later - Variations, a way to generate same image but with tiny differences @@ -75,13 +76,22 @@ A browser interface based on Gradio library for Stable Diffusion. - hypernetworks and embeddings options - Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime) - Clip skip -- Use Hypernetworks -- Use VAEs +- Hypernetworks +- Loras (same as Hypernetworks but prettier) +- A separate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt. +- Can select to load a different VAE from settings screen - Estimated completion time in progress bar - API - Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. - via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip image embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients)) - [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions +- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions +- Now without any bad letters! +- Load checkpoints in safetensors format +- Eased resolution restriction: generated image's dimension must be a multiple of 8 rather than 64 +- Now with a license! 
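A rough sketch of how an extra network's prompt parameters are consumed, matching the ExtraNetworkLora.activate code added further down in this patch; the list-of-lists stand-in for the parsed parameters, the <lora:name:multiplier> tag form, and the example names are assumptions:

    # Hypothetical stand-in for the parsed <lora:name:multiplier> prompt tags.
    params_list = [["my_style_lora", "0.8"], ["other_lora"]]

    names, multipliers = [], []
    for items in params_list:
        assert len(items) > 0
        names.append(items[0])
        # A missing multiplier falls back to full strength, as in activate().
        multipliers.append(float(items[1]) if len(items) > 1 else 1.0)

    print(names)        # ['my_style_lora', 'other_lora']
    print(multipliers)  # [0.8, 1.0]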
+- Reorder elements in the UI from settings screen +- ## Installation and Running Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py new file mode 100644 index 00000000..8f2e753e --- /dev/null +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -0,0 +1,20 @@ +from modules import extra_networks +import lora + +class ExtraNetworkLora(extra_networks.ExtraNetwork): + def __init__(self): + super().__init__('lora') + + def activate(self, p, params_list): + names = [] + multipliers = [] + for params in params_list: + assert len(params.items) > 0 + + names.append(params.items[0]) + multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) + + lora.load_loras(names, multipliers) + + def deactivate(self, p): + pass diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py new file mode 100644 index 00000000..7a3ad9a9 --- /dev/null +++ b/extensions-builtin/Lora/lora.py @@ -0,0 +1,198 @@ +import glob +import os +import re +import torch + +from modules import shared, devices, sd_models + +re_digits = re.compile(r"\d+") +re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)") +re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)") +re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)") +re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)") + + +def convert_diffusers_name_to_compvis(key): + def match(match_list, regex): + r = re.match(regex, key) + if not r: + return False + + match_list.clear() + match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) + return True + + m = [] + + if match(m, re_unet_down_blocks): + return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}" + + if match(m, re_unet_mid_blocks): + return f"diffusion_model_middle_block_1_{m[1]}" + + if match(m, re_unet_up_blocks): + return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}" + + if match(m, re_text_block): + return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" + + return key + + +class LoraOnDisk: + def __init__(self, name, filename): + self.name = name + self.filename = filename + + +class LoraModule: + def __init__(self, name): + self.name = name + self.multiplier = 1.0 + self.modules = {} + self.mtime = None + + +class LoraUpDownModule: + def __init__(self): + self.up = None + self.down = None + + +def assign_lora_names_to_compvis_modules(sd_model): + lora_layer_mapping = {} + + for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): + lora_name = name.replace(".", "_") + lora_layer_mapping[lora_name] = module + module.lora_layer_name = lora_name + + for name, module in shared.sd_model.model.named_modules(): + lora_name = name.replace(".", "_") + lora_layer_mapping[lora_name] = module + module.lora_layer_name = lora_name + + sd_model.lora_layer_mapping = lora_layer_mapping + + +def load_lora(name, filename): + lora = LoraModule(name) + lora.mtime = os.path.getmtime(filename) + + sd = sd_models.read_state_dict(filename) + + keys_failed_to_match = [] + + for 
key_diffusers, weight in sd.items(): + fullkey = convert_diffusers_name_to_compvis(key_diffusers) + key, lora_key = fullkey.split(".", 1) + + sd_module = shared.sd_model.lora_layer_mapping.get(key, None) + if sd_module is None: + keys_failed_to_match.append(key_diffusers) + continue + + if type(sd_module) == torch.nn.Linear: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(sd_module) == torch.nn.Conv2d: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + else: + assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}' + + with torch.no_grad(): + module.weight.copy_(weight) + + module.to(device=devices.device, dtype=devices.dtype) + + lora_module = lora.modules.get(key, None) + if lora_module is None: + lora_module = LoraUpDownModule() + lora.modules[key] = lora_module + + if lora_key == "lora_up.weight": + lora_module.up = module + elif lora_key == "lora_down.weight": + lora_module.down = module + else: + assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight or lora_down.weight' + + if len(keys_failed_to_match) > 0: + print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}") + + return lora + + +def load_loras(names, multipliers=None): + already_loaded = {} + + for lora in loaded_loras: + if lora.name in names: + already_loaded[lora.name] = lora + + loaded_loras.clear() + + loras_on_disk = [available_loras.get(name, None) for name in names] + if any([x is None for x in loras_on_disk]): + list_available_loras() + + loras_on_disk = [available_loras.get(name, None) for name in names] + + for i, name in enumerate(names): + lora = already_loaded.get(name, None) + + lora_on_disk = loras_on_disk[i] + if lora_on_disk is not None: + if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime: + lora = load_lora(name, lora_on_disk.filename) + + if lora is None: + print(f"Couldn't find Lora with name {name}") + continue + + lora.multiplier = multipliers[i] if multipliers else 1.0 + loaded_loras.append(lora) + + +def lora_forward(module, input, res): + if len(loaded_loras) == 0: + return res + + lora_layer_name = getattr(module, 'lora_layer_name', None) + for lora in loaded_loras: + module = lora.modules.get(lora_layer_name, None) + if module is not None: + res = res + module.up(module.down(input)) * lora.multiplier + + return res + + +def lora_Linear_forward(self, input): + return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input)) + + +def lora_Conv2d_forward(self, input): + return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input)) + + +def list_available_loras(): + available_loras.clear() + + os.makedirs(lora_dir, exist_ok=True) + + candidates = glob.glob(os.path.join(lora_dir, '**/*.pt'), recursive=True) + glob.glob(os.path.join(lora_dir, '**/*.safetensors'), recursive=True) + + for filename in sorted(candidates): + if os.path.isdir(filename): + continue + + name = os.path.splitext(os.path.basename(filename))[0] + + available_loras[name] = LoraOnDisk(name, filename) + + +lora_dir = os.path.join(shared.models_path, "Lora") +available_loras = {} +loaded_loras = [] + +list_available_loras() + diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py new file mode 100644 index 00000000..60b9eb64 --- /dev/null +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -0,0 +1,30 @@ +import torch + +import lora +import 
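# A minimal sketch of the low-rank update that lora_forward applies above:
# the frozen layer's output is augmented with up(down(x)) scaled by the
# multiplier. The dimensions and rank below are illustrative assumptions.
import torch

in_features, out_features, rank = 16, 16, 4
base = torch.nn.Linear(in_features, out_features)      # frozen original layer
down = torch.nn.Linear(in_features, rank, bias=False)  # lora_down.weight
up = torch.nn.Linear(rank, out_features, bias=False)   # lora_up.weight
multiplier = 0.8

x = torch.randn(2, in_features)
res = base(x) + up(down(x)) * multiplier  # same composition as lora_forward
print(res.shape)  # torch.Size([2, 16])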
extra_networks_lora +import ui_extra_networks_lora +from modules import script_callbacks, ui_extra_networks, extra_networks + + +def unload(): + torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora + torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora + + +def before_ui(): + ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora()) + extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora()) + + +if not hasattr(torch.nn, 'Linear_forward_before_lora'): + torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward + +if not hasattr(torch.nn, 'Conv2d_forward_before_lora'): + torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward + +torch.nn.Linear.forward = lora.lora_Linear_forward +torch.nn.Conv2d.forward = lora.lora_Conv2d_forward + +script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) +script_callbacks.on_script_unloaded(unload) +script_callbacks.on_before_ui(before_ui) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py new file mode 100644 index 00000000..65397890 --- /dev/null +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -0,0 +1,35 @@ +import os +import lora + +from modules import shared, ui_extra_networks + + +class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Lora') + + def refresh(self): + lora.list_available_loras() + + def list_items(self): + for name, lora_on_disk in lora.available_loras.items(): + path, ext = os.path.splitext(lora_on_disk.filename) + previews = [path + ".png", path + ".preview.png"] + + preview = None + for file in previews: + if os.path.isfile(file): + preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file)) + break + + yield { + "name": name, + "filename": path, + "preview": preview, + "prompt": f"", + "local_preview": path + ".png", + } + + def allowed_directories_for_previews(self): + return [lora.lora_dir] + diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py index 6a0c4ba8..ff279a1f 100644 --- a/modules/extra_networks_hypernet.py +++ b/modules/extra_networks_hypernet.py @@ -17,5 +17,5 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork): hypernetwork.load_hypernetworks(names, multipliers) - def deactivate(p, self): + def deactivate(self, p): pass diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index a9e19236..4bb45ec7 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -73,6 +73,7 @@ callback_map = dict( callbacks_image_grid=[], callbacks_infotext_pasted=[], callbacks_script_unloaded=[], + callbacks_before_ui=[], ) @@ -189,6 +190,14 @@ def script_unloaded_callback(): report_exception(c, 'script_unloaded') +def before_ui_callback(): + for c in reversed(callback_map['callbacks_before_ui']): + try: + c.callback() + except Exception: + report_exception(c, 'before_ui') + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if len(stack) > 0 else 'unknown file' @@ -313,3 +322,9 @@ def on_script_unloaded(callback): the script did should be reverted here""" add_callback(callback_map['callbacks_script_unloaded'], callback) + + +def on_before_ui(callback): + """register a function to be called before the UI is created.""" + + add_callback(callback_map['callbacks_before_ui'], callback) diff --git a/modules/ui_extra_networks.py 
b/modules/ui_extra_networks.py index 253e90f7..796e879c 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -10,7 +10,7 @@ extra_pages = [] def register_page(page): - """registers extra networks page for the UI; recommend doing it in on_app_started() callback for extensions""" + """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions""" extra_pages.append(page) diff --git a/webui.py b/webui.py index e8dd822a..88d04840 100644 --- a/webui.py +++ b/webui.py @@ -165,6 +165,8 @@ def webui(): if shared.opts.clean_temp_dir_at_start: ui_tempdir.cleanup_tmpdr() + modules.script_callbacks.before_ui_callback() + shared.demo = modules.ui.create_ui() app, local_url, share_url = shared.demo.launch( -- cgit v1.2.3 From 63b824376c49013880ff44c260ea426e2899511e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 18:47:54 +0300 Subject: add --gradio-queue option to enable gradio queue --- modules/shared.py | 2 ++ webui.py | 3 +++ 2 files changed, 5 insertions(+) (limited to 'webui.py') diff --git a/modules/shared.py b/modules/shared.py index 72fb1934..52bbb807 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -100,6 +100,8 @@ parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS o parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None) parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) +parser.add_argument("--gradio-queue", action='store_true', help="Uses gradio queue; experimental option; breaks restart UI button") + script_loading.preload_extensions(extensions.extensions_dir, parser) script_loading.preload_extensions(extensions.extensions_builtin_dir, parser) diff --git a/webui.py b/webui.py index 88d04840..d235da74 100644 --- a/webui.py +++ b/webui.py @@ -169,6 +169,9 @@ def webui(): shared.demo = modules.ui.create_ui() + if cmd_opts.gradio_queue: + shared.demo.queue(64) + app, local_url, share_url = shared.demo.launch( share=cmd_opts.share, server_name=server_name, -- cgit v1.2.3 From 68303c96e5ab31576a8238a24bf5b6191cf16ed1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 22 Jan 2023 15:38:39 +0300 Subject: split oversize extras.py to postprocessing.py --- modules/extras.py | 217 +------------------------------------- modules/postprocessing.py | 257 +--------------------------------------------- modules/ui.py | 10 +- modules/ui_components.py | 7 ++ webui.py | 1 - 5 files changed, 18 insertions(+), 474 deletions(-) (limited to 'webui.py') diff --git a/modules/extras.py b/modules/extras.py index 385430dc..f04ddfc2 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -1,231 +1,16 @@ -from __future__ import annotations -import math import os import re -import sys -import traceback import shutil -import numpy as np -from PIL import Image import torch import tqdm -from typing import Callable, List, OrderedDict, Tuple -from functools import partial -from dataclasses import dataclass - -from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae -from modules.shared import opts -import modules.gfpgan_model +from modules import shared, images, sd_models, sd_vae from modules.ui import plaintext_to_html -import modules.codeformer_model import gradio as gr import 
safetensors.torch -class LruCache(OrderedDict): - @dataclass(frozen=True) - class Key: - image_hash: int - info_hash: int - args_hash: int - - @dataclass - class Value: - image: Image.Image - info: str - - def __init__(self, max_size: int = 5, *args, **kwargs): - super().__init__(*args, **kwargs) - self._max_size = max_size - - def get(self, key: LruCache.Key) -> LruCache.Value: - ret = super().get(key) - if ret is not None: - self.move_to_end(key) # Move to end of eviction list - return ret - - def put(self, key: LruCache.Key, value: LruCache.Value) -> None: - self[key] = value - while len(self) > self._max_size: - self.popitem(last=False) - - -cached_images: LruCache = LruCache(max_size=5) - - -def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): - devices.torch_gc() - - shared.state.begin() - shared.state.job = 'extras' - - imageArr = [] - # Also keep track of original file names - imageNameArr = [] - outputs = [] - - if extras_mode == 1: - #convert file to pillow image - for img in image_folder: - image = Image.open(img) - imageArr.append(image) - imageNameArr.append(os.path.splitext(img.orig_name)[0]) - elif extras_mode == 2: - assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled' - - if input_dir == '': - return outputs, "Please select an input directory.", '' - image_list = shared.listfiles(input_dir) - for img in image_list: - try: - image = Image.open(img) - except Exception: - continue - imageArr.append(image) - imageNameArr.append(img) - else: - imageArr.append(image) - imageNameArr.append(None) - - if extras_mode == 2 and output_dir != '': - outpath = output_dir - else: - outpath = opts.outdir_samples or opts.outdir_extras_samples - - # Extra operation definitions - - def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-gfpgan' - restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) - res = Image.fromarray(restored_img) - - if gfpgan_visibility < 1.0: - res = Image.blend(image, res, gfpgan_visibility) - - info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" - return (res, info) - - def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - shared.state.job = 'extras-codeformer' - restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight) - res = Image.fromarray(restored_img) - - if codeformer_visibility < 1.0: - res = Image.blend(image, res, codeformer_visibility) - - info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" - return (res, info) - - def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): - shared.state.job = 'extras-upscale' - upscaler = shared.sd_upscalers[scaler_index] - res = upscaler.scaler.upscale(image, resize, upscaler.data_path) - if mode == 1 and crop: - cropped = Image.new("RGB", (resize_w, resize_h)) - cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2)) - res = cropped - return res - - def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: - # Actual crop happens in run_upscalers_blend, this just sets 
upscaling_resize and adds info text - nonlocal upscaling_resize - if resize_mode == 1: - upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) - crop_info = " (crop)" if upscaling_crop else "" - info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" - return (image, info) - - @dataclass - class UpscaleParams: - upscaler_idx: int - blend_alpha: float - - def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: - blended_result: Image.Image = None - image_hash: str = hash(np.array(image.getdata()).tobytes()) - for upscaler in params: - upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, - upscaling_resize_w, upscaling_resize_h, upscaling_crop) - cache_key = LruCache.Key(image_hash=image_hash, - info_hash=hash(info), - args_hash=hash(upscale_args)) - cached_entry = cached_images.get(cache_key) - if cached_entry is None: - res = upscale(image, *upscale_args) - info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" - cached_images.put(cache_key, LruCache.Value(image=res, info=info)) - else: - res, info = cached_entry.image, cached_entry.info - - if blended_result is None: - blended_result = res - else: - blended_result = Image.blend(blended_result, res, upscaler.blend_alpha) - return (blended_result, info) - - # Build a list of operations to run - facefix_ops: List[Callable] = [] - facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else [] - facefix_ops += [run_codeformer] if codeformer_visibility > 0 else [] - - upscale_ops: List[Callable] = [] - upscale_ops += [run_prepare_crop] if resize_mode == 1 else [] - - if upscaling_resize != 0: - step_params: List[UpscaleParams] = [] - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0)) - if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: - step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility)) - - upscale_ops.append(partial(run_upscalers_blend, step_params)) - - extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops) - - for image, image_name in zip(imageArr, imageNameArr): - if image is None: - return outputs, "Please select an input image.", '' - - shared.state.textinfo = f'Processing image {image_name}' - - existing_pnginfo = image.info or {} - - image = image.convert("RGB") - info = "" - # Run each operation on each image - for op in extras_ops: - image, info = op(image, info) - - if opts.use_original_name_batch and image_name is not None: - basename = os.path.splitext(os.path.basename(image_name))[0] - else: - basename = '' - - if opts.enable_pnginfo: # append info before save - image.info = existing_pnginfo - image.info["extras"] = info - - if save_output: - # Add upscaler name as a suffix. - suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else "" - # Add second upscaler if applicable. 
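# A small sketch of the secondary-upscaler blend retained in
# modules/postprocessing.py: the second upscaler's output is alpha-blended
# over the first, with the visibility slider as the blend factor. The
# solid-colour images stand in for real upscaler results; values are
# assumptions.
from PIL import Image

first = Image.new("RGB", (64, 64), (255, 0, 0))   # stand-in for upscaler 1 output
second = Image.new("RGB", (64, 64), (0, 0, 255))  # stand-in for upscaler 2 output
visibility = 0.25                                 # extras_upscaler_2_visibility

blended = Image.blend(first, second, visibility)  # (1 - a) * first + a * second
print(blended.getpixel((0, 0)))                   # roughly (191, 0, 64)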
- if suffix and extras_upscaler_2 and extras_upscaler_2_visibility: - suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}" - - images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True, - no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix) - - if extras_mode != 2 or show_extras_results : - outputs.append(image) - - devices.torch_gc() - - return outputs, plaintext_to_html(info), '' - -def clear_cache(): - cached_images.clear() - def run_pnginfo(image): if image is None: diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 385430dc..cb85720b 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -1,28 +1,18 @@ from __future__ import annotations -import math import os -import re -import sys -import traceback -import shutil import numpy as np from PIL import Image -import torch -import tqdm - from typing import Callable, List, OrderedDict, Tuple from functools import partial from dataclasses import dataclass -from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae +from modules import shared, images, devices, ui_components from modules.shared import opts import modules.gfpgan_model -from modules.ui import plaintext_to_html import modules.codeformer_model -import gradio as gr -import safetensors.torch + class LruCache(OrderedDict): @dataclass(frozen=True) @@ -55,7 +45,7 @@ class LruCache(OrderedDict): cached_images: LruCache = LruCache(max_size=5) -def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): +def run_postprocessing(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True): devices.torch_gc() shared.state.begin() @@ -221,246 +211,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ devices.torch_gc() - return outputs, plaintext_to_html(info), '' + return outputs, ui_components.plaintext_to_html(info), '' + def clear_cache(): cached_images.clear() - -def run_pnginfo(image): - if image is None: - return '', '', '' - - geninfo, items = images.read_info_from_image(image) - items = {**{'parameters': geninfo}, **items} - - info = '' - for key, text in items.items(): - info += f""" -
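[Note] The LruCache removed above survives in modules/postprocessing.py (next diff): it is a plain OrderedDict used as an LRU cache for upscale results. get() calls move_to_end() so a hit becomes the most recently used entry, and put() evicts with popitem(last=False), which pops the oldest entry once the cache exceeds max_size. A minimal standalone sketch of the same pattern -- the class name and demo values are illustrative, not part of the patch:

    from collections import OrderedDict

    class MiniLru(OrderedDict):
        def __init__(self, max_size: int = 5):
            super().__init__()
            self.max_size = max_size

        def get(self, key, default=None):
            if key not in self:
                return default
            self.move_to_end(key)         # mark as most recently used
            return self[key]

        def put(self, key, value):
            self[key] = value
            self.move_to_end(key)         # refresh position on overwrite
            while len(self) > self.max_size:
                self.popitem(last=False)  # evict the least recently used

    cache = MiniLru(max_size=2)
    cache.put('a', 1)
    cache.put('b', 2)
    cache.get('a')       # 'a' is now the most recent entry
    cache.put('c', 3)    # evicts 'b', keeps 'a'

The original wraps hashes of the image bytes, the accumulated info string and the upscale arguments in a frozen dataclass (LruCache.Key) so that identical inputs hit the same cache entry rather than relying on object identity.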
diff --git a/modules/postprocessing.py b/modules/postprocessing.py
index 385430dc..cb85720b 100644
--- a/modules/postprocessing.py
+++ b/modules/postprocessing.py
@@ -1,28 +1,18 @@
 from __future__ import annotations
-import math
 import os
-import re
-import sys
-import traceback
-import shutil
 
 import numpy as np
 from PIL import Image
 
-import torch
-import tqdm
-
 from typing import Callable, List, OrderedDict, Tuple
 from functools import partial
 from dataclasses import dataclass
 
-from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae
+from modules import shared, images, devices, ui_components
 from modules.shared import opts
 import modules.gfpgan_model
-from modules.ui import plaintext_to_html
 import modules.codeformer_model
-import gradio as gr
-import safetensors.torch
+
 
 class LruCache(OrderedDict):
     @dataclass(frozen=True)
@@ -55,7 +45,7 @@ class LruCache(OrderedDict):
 cached_images: LruCache = LruCache(max_size=5)
 
 
-def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
+def run_postprocessing(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
     devices.torch_gc()
 
     shared.state.begin()
@@ -221,246 +211,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
 
     devices.torch_gc()
 
-    return outputs, plaintext_to_html(info), ''
+    return outputs, ui_components.plaintext_to_html(info), ''
+
 
 def clear_cache():
     cached_images.clear()
-
-def run_pnginfo(image):
-    if image is None:
-        return '', '', ''
-
-    geninfo, items = images.read_info_from_image(image)
-    items = {**{'parameters': geninfo}, **items}
-
-    info = ''
-    for key, text in items.items():
-        info += f"""
-<div>
-<p><b>{plaintext_to_html(str(key))}</b></p>
-<p>{plaintext_to_html(str(text))}</p>
-</div>
-""".strip()+"\n"
-
-    if len(info) == 0:
-        message = "Nothing found in the image."
-        info = f"<div><p>{message}<p></div>"
-
-    return '', geninfo, info
-
-
-def create_config(ckpt_result, config_source, a, b, c):
-    def config(x):
-        res = sd_models.find_checkpoint_config(x) if x else None
-        return res if res != shared.sd_default_config else None
-
-    if config_source == 0:
-        cfg = config(a) or config(b) or config(c)
-    elif config_source == 1:
-        cfg = config(b)
-    elif config_source == 2:
-        cfg = config(c)
-    else:
-        cfg = None
-
-    if cfg is None:
-        return
-
-    filename, _ = os.path.splitext(ckpt_result)
-    checkpoint_filename = filename + ".yaml"
-
-    print("Copying config:")
-    print("   from:", cfg)
-    print("     to:", checkpoint_filename)
-    shutil.copyfile(cfg, checkpoint_filename)
-
-
-checkpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
-
-
-def to_half(tensor, enable):
-    if enable and tensor.dtype == torch.float:
-        return tensor.half()
-
-    return tensor
-
-
-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
-    shared.state.begin()
-    shared.state.job = 'model-merge'
-
-    def fail(message):
-        shared.state.textinfo = message
-        shared.state.end()
-        return [*[gr.update() for _ in range(4)], message]
-
-    def weighted_sum(theta0, theta1, alpha):
-        return ((1 - alpha) * theta0) + (alpha * theta1)
-
-    def get_difference(theta1, theta2):
-        return theta1 - theta2
-
-    def add_difference(theta0, theta1_2_diff, alpha):
-        return theta0 + (alpha * theta1_2_diff)
-
-    def filename_weighted_sum():
-        a = primary_model_info.model_name
-        b = secondary_model_info.model_name
-        Ma = round(1 - multiplier, 2)
-        Mb = round(multiplier, 2)
-
-        return f"{Ma}({a}) + {Mb}({b})"
-
-    def filename_add_difference():
-        a = primary_model_info.model_name
-        b = secondary_model_info.model_name
-        c = tertiary_model_info.model_name
-        M = round(multiplier, 2)
-
-        return f"{a} + {M}({b} - {c})"
-
-    def filename_nothing():
-        return primary_model_info.model_name
-
-    theta_funcs = {
-        "Weighted sum": (filename_weighted_sum, None, weighted_sum),
-        "Add difference": (filename_add_difference, get_difference, add_difference),
-        "No interpolation": (filename_nothing, None, None),
-    }
-    filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
-    shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
-
-    if not primary_model_name:
-        return fail("Failed: Merging requires a primary model.")
-
-    primary_model_info = sd_models.checkpoints_list[primary_model_name]
-
-    if theta_func2 and not secondary_model_name:
-        return fail("Failed: Merging requires a secondary model.")
-
-    secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
-
-    if theta_func1 and not tertiary_model_name:
-        return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
-
-    tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
-
-    result_is_inpainting_model = False
-
-    if theta_func2:
-        shared.state.textinfo = f"Loading B"
-        print(f"Loading {secondary_model_info.filename}...")
-        theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
-    else:
-        theta_1 = None
-
-    if theta_func1:
-        shared.state.textinfo = f"Loading C"
-        print(f"Loading {tertiary_model_info.filename}...")
-        theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
-
-        shared.state.textinfo = 'Merging B and C'
-        shared.state.sampling_steps = len(theta_1.keys())
-        for key in tqdm.tqdm(theta_1.keys()):
-            if key in checkpoint_dict_skip_on_merge:
-                continue
-
-            if 'model' in key:
-                if key in theta_2:
-                    t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
-                    theta_1[key] = theta_func1(theta_1[key], t2)
-                else:
-                    theta_1[key] = torch.zeros_like(theta_1[key])
-
-            shared.state.sampling_step += 1
-        del theta_2
-
-        shared.state.nextjob()
-
-    shared.state.textinfo = f"Loading {primary_model_info.filename}..."
-    print(f"Loading {primary_model_info.filename}...")
-    theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
-
-    print("Merging...")
-    shared.state.textinfo = 'Merging A and B'
-    shared.state.sampling_steps = len(theta_0.keys())
-    for key in tqdm.tqdm(theta_0.keys()):
-        if theta_1 and 'model' in key and key in theta_1:
-
-            if key in checkpoint_dict_skip_on_merge:
-                continue
-
-            a = theta_0[key]
-            b = theta_1[key]
-
-            # this enables merging an inpainting model (A) with another one (B);
-            # where normal model would have 4 channels, for latenst space, inpainting model would
-            # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
-            if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
-                if a.shape[1] == 4 and b.shape[1] == 9:
-                    raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")
-
-                assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
-
-                theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
-                result_is_inpainting_model = True
-            else:
-                theta_0[key] = theta_func2(a, b, multiplier)
-
-            theta_0[key] = to_half(theta_0[key], save_as_half)
-
-        shared.state.sampling_step += 1
-
-    del theta_1
-
-    bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
-    if bake_in_vae_filename is not None:
-        print(f"Baking in VAE from {bake_in_vae_filename}")
-        shared.state.textinfo = 'Baking in VAE'
-        vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
-
-        for key in vae_dict.keys():
-            theta_0_key = 'first_stage_model.' + key
-            if theta_0_key in theta_0:
-                theta_0[theta_0_key] = to_half(vae_dict[key], save_as_half)
-
-        del vae_dict
-
-    if save_as_half and not theta_func2:
-        for key in theta_0.keys():
-            theta_0[key] = to_half(theta_0[key], save_as_half)
-
-    if discard_weights:
-        regex = re.compile(discard_weights)
-        for key in list(theta_0):
-            if re.search(regex, key):
-                theta_0.pop(key, None)
-
-    ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
-
-    filename = filename_generator() if custom_name == '' else custom_name
-    filename += ".inpainting" if result_is_inpainting_model else ""
-    filename += "." + checkpoint_format
-
-    output_modelname = os.path.join(ckpt_dir, filename)
-
-    shared.state.nextjob()
-    shared.state.textinfo = "Saving"
-    print(f"Saving to {output_modelname}...")
-
-    _, extension = os.path.splitext(output_modelname)
-    if extension.lower() == ".safetensors":
-        safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
-    else:
-        torch.save(theta_0, output_modelname)
-
-    sd_models.list_models()
-
-    create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
-
-    print(f"Checkpoint saved to {output_modelname}.")
-    shared.state.textinfo = "Checkpoint saved"
-    shared.state.end()
-
-    return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
diff --git a/modules/ui.py b/modules/ui.py
index eb4b7e6b..4116e167 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -20,7 +20,7 @@ import numpy as np
 from PIL import Image, PngImagePlugin
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
 
-from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks
+from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, postprocessing, ui_components
 from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
 from modules.paths import script_path
 
@@ -95,8 +95,8 @@ extra_networks_symbol = '\U0001F3B4'  # 🎴
 
 
 def plaintext_to_html(text):
-    text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
-    return text
+    return ui_components.plaintext_to_html(text)
+
 
 def send_gradio_gallery_to_image(x):
     if len(x) == 0:
@@ -1152,7 +1152,7 @@ def create_ui():
             result_images, html_info_x, html_info, html_log = create_output_panel("extras", opts.outdir_extras_samples)
 
         submit.click(
-            fn=wrap_gradio_gpu_call(modules.extras.run_extras, extra_outputs=[None, '']),
+            fn=wrap_gradio_gpu_call(postprocessing.run_postprocessing, extra_outputs=[None, '']),
            _js="get_extras_tab_index",
            inputs=[
                dummy_component,
@@ -1183,7 +1183,7 @@ def create_ui():
         parameters_copypaste.add_paste_fields("extras", extras_image, None)
 
         extras_image.change(
-            fn=modules.extras.clear_cache,
+            fn=postprocessing.clear_cache,
            inputs=[],
            outputs=[]
        )
diff --git a/modules/ui_components.py b/modules/ui_components.py
index 46324425..989cc87b 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -1,3 +1,5 @@
+import html
+
 import gradio as gr
 
 
@@ -47,3 +49,8 @@ class FormColorPicker(gr.ColorPicker, gr.components.FormComponent):
 
     def get_block_name(self):
         return "colorpicker"
+
+
+def plaintext_to_html(text):
+    text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
+    return text
diff --git a/webui.py b/webui.py
index d235da74..7cf5885e 100644
--- a/webui.py
+++ b/webui.py
@@ -22,7 +22,6 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
 
 from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks
 import modules.codeformer_model as codeformer
-import modules.extras
 import modules.face_restoration
 import modules.gfpgan_model as gfpgan
 import modules.img2img
-- cgit v1.2.3
+ """.strip()) + + def initialize(): + check_versions() + extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) -- cgit v1.2.3 From 865af20d8a4a823df3c950f5c9c9092a541bc57a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 23 Jan 2023 21:28:59 +0300 Subject: suppress A matching Triton is not available message you can all now stop worrying about it --- webui.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'webui.py') diff --git a/webui.py b/webui.py index bc2baeab..e1565a8d 100644 --- a/webui.py +++ b/webui.py @@ -1,6 +1,5 @@ import os import sys -import threading import time import importlib import signal @@ -10,6 +9,9 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware from packaging import version +import logging +logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) + from modules import import_hook, errors, extra_networks from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call -- cgit v1.2.3 From 6b3981c0685cd1df750df4eb51823f1cfd70c6d5 Mon Sep 17 00:00:00 2001 From: Max Audron Date: Wed, 25 Jan 2023 18:00:09 +0100 Subject: clean up unused script_path imports --- modules/codeformer_model.py | 2 +- modules/generation_parameters_copypaste.py | 2 +- webui.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) (limited to 'webui.py') diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index ab40d842..01fb7bd8 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -8,7 +8,7 @@ import torch import modules.face_restoration import modules.shared from modules import shared, devices, modelloader -from modules.paths import script_path, models_path +from modules.paths import models_path # codeformer people made a choice to include modified basicsr library to their project which makes # it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN. diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 35f72808..773c5c0e 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -6,7 +6,7 @@ import re from pathlib import Path import gradio as gr -from modules.paths import data_path, script_path +from modules.paths import data_path from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image diff --git a/webui.py b/webui.py index e1565a8d..41f32f5c 100644 --- a/webui.py +++ b/webui.py @@ -15,7 +15,6 @@ logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not from modules import import_hook, errors, extra_networks from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call -from modules.paths import script_path import torch -- cgit v1.2.3