From 339b5315700a469f4a9f0d5afc08ca2aca60c579 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 27 May 2023 15:47:33 +0300 Subject: custom unet support --- modules/processing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 29a3743f..b75f2515 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -674,6 +674,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN": sd_vae_approx.model() + sd_unet.apply_unet() + if state.job_count == -1: state.job_count = p.n_iter -- cgit v1.2.3 From 00dfe27f59727407c5b408a80ff2a262934df495 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 08:54:13 +0300 Subject: Add & use modules.errors.print_error where currently printing exception info by hand --- extensions-builtin/LDSR/scripts/ldsr_model.py | 7 ++--- extensions-builtin/ScuNET/scripts/scunet_model.py | 6 ++-- modules/api/api.py | 7 +++-- modules/call_queue.py | 22 ++++++-------- modules/codeformer_model.py | 10 +++---- modules/config_states.py | 12 +++----- modules/errors.py | 16 +++++++++++ modules/extensions.py | 10 +++---- modules/gfpgan_model.py | 6 ++-- modules/hypernetworks/hypernetwork.py | 14 ++++----- modules/images.py | 9 ++---- modules/interrogate.py | 5 ++-- modules/launch_utils.py | 7 +++-- modules/localization.py | 6 ++-- modules/processing.py | 2 +- modules/realesrgan_model.py | 14 ++++----- modules/safe.py | 26 +++++++++-------- modules/script_callbacks.py | 9 +++--- modules/script_loading.py | 7 ++--- modules/scripts.py | 35 ++++++++--------------- modules/sd_hijack_optimizations.py | 6 ++-- modules/textual_inversion/textual_inversion.py | 9 ++---- modules/ui.py | 10 +++---- modules/ui_extensions.py | 9 ++---- scripts/prompts_from_file.py | 6 ++-- 25 files changed, 117 insertions(+), 153 deletions(-) (limited to 'modules/processing.py') diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index c4da79f3..95f1669d 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -1,9 +1,8 @@ import os -import sys -import traceback from basicsr.utils.download_util import load_file_from_url +from modules.errors import print_error from modules.upscaler import Upscaler, UpscalerData from ldsr_model_arch import LDSR from modules import shared, script_callbacks @@ -51,10 +50,8 @@ class UpscalerLDSR(Upscaler): try: return LDSR(model, yaml) - except Exception: - print("Error importing LDSR:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error importing LDSR", exc_info=True) return None def do_upscale(self, img, path): diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 45d9297b..dd1b822e 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -1,6 +1,5 @@ import os.path import sys -import traceback import PIL.Image import numpy as np @@ -12,6 +11,8 @@ from basicsr.utils.download_util import load_file_from_url import modules.upscaler from modules import devices, modelloader, script_callbacks from scunet_model_arch import SCUNet as net + +from modules.errors import print_error from modules.shared import opts @@ -38,8 +39,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) scalers.append(scaler_data) except Exception: - print(f"Error loading ScuNET model: {file}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading ScuNET model: {file}", exc_info=True) if add_model2: scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self) scalers.append(scaler_data2) diff --git a/modules/api/api.py b/modules/api/api.py index 6a456861..79ce9228 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -16,6 +16,7 @@ from secrets import compare_digest import modules.shared as shared from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing from modules.api import models +from modules.errors import print_error from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.textual_inversion.textual_inversion import create_embedding, train_embedding @@ -108,7 +109,6 @@ def api_middleware(app: FastAPI): from rich.console import Console console = Console() except Exception: - import traceback rich_available = False @app.middleware("http") @@ -139,11 +139,12 @@ def api_middleware(app: FastAPI): "errors": str(e), } if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions - print(f"API error: {request.method}: {request.url} {err}") + message = f"API error: {request.method}: {request.url} {err}" if rich_available: + print(message) console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200])) else: - traceback.print_exc() + print_error(message, exc_info=True) return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err)) @app.middleware("http") diff --git a/modules/call_queue.py b/modules/call_queue.py index 447bb764..dba2a9b4 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -1,10 +1,9 @@ import html -import sys import threading -import traceback import time from modules import shared, progress +from modules.errors import print_error queue_lock = threading.Lock() @@ -56,16 +55,14 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): try: res = list(func(*args, **kwargs)) except Exception as e: - # When printing out our debug argument list, do not print out more than a MB of text - max_debug_str_len = 131072 # (1024*1024)/8 - - print("Error completing request", file=sys.stderr) - argStr = f"Arguments: {args} {kwargs}" - print(argStr[:max_debug_str_len], file=sys.stderr) - if len(argStr) > max_debug_str_len: - print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) - - print(traceback.format_exc(), file=sys.stderr) + # When printing out our debug argument list, + # do not print out more than a 100 KB of text + max_debug_str_len = 131072 + message = "Error completing request" + arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len] + if len(arg_str) > max_debug_str_len: + arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)" + print_error(f"{message}\n{arg_str}", exc_info=True) shared.state.job = "" shared.state.job_count = 0 @@ -108,4 +105,3 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): return tuple(res) return f - diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index ececdbae..76143e9f 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -1,6 +1,4 @@ import os -import sys -import traceback import cv2 import torch @@ -8,6 +6,7 @@ import torch import modules.face_restoration import modules.shared from modules import shared, devices, modelloader +from modules.errors import print_error from modules.paths import models_path # codeformer people made a choice to include modified basicsr library to their project which makes @@ -105,8 +104,8 @@ def setup_model(dirname): restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) del output torch.cuda.empty_cache() - except Exception as error: - print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr) + except Exception: + print_error('Failed inference for CodeFormer', exc_info=True) restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) restored_face = restored_face.astype('uint8') @@ -135,7 +134,6 @@ def setup_model(dirname): shared.face_restorers.append(codeformer) except Exception: - print("Error setting up CodeFormer:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error setting up CodeFormer", exc_info=True) # sys.path = stored_sys_path diff --git a/modules/config_states.py b/modules/config_states.py index db65bcdb..faeaf28b 100644 --- a/modules/config_states.py +++ b/modules/config_states.py @@ -3,8 +3,6 @@ Supports saving and restoring webui and extensions from a known working set of c """ import os -import sys -import traceback import json import time import tqdm @@ -14,6 +12,7 @@ from collections import OrderedDict import git from modules import shared, extensions +from modules.errors import print_error from modules.paths_internal import script_path, config_states_dir @@ -53,8 +52,7 @@ def get_webui_config(): if os.path.exists(os.path.join(script_path, ".git")): webui_repo = git.Repo(script_path) except Exception: - print(f"Error reading webui git info from {script_path}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error reading webui git info from {script_path}", exc_info=True) webui_remote = None webui_commit_hash = None @@ -134,8 +132,7 @@ def restore_webui_config(config): if os.path.exists(os.path.join(script_path, ".git")): webui_repo = git.Repo(script_path) except Exception: - print(f"Error reading webui git info from {script_path}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error reading webui git info from {script_path}", exc_info=True) return try: @@ -143,8 +140,7 @@ def restore_webui_config(config): webui_repo.git.reset(webui_commit_hash, hard=True) print(f"* Restored webui to commit {webui_commit_hash}.") except Exception: - print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error restoring webui to commit{webui_commit_hash}") def restore_extension_config(config): diff --git a/modules/errors.py b/modules/errors.py index da4694f8..41d8dc93 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -1,7 +1,23 @@ import sys +import textwrap import traceback +def print_error( + message: str, + *, + exc_info: bool = False, +) -> None: + """ + Print an error message to stderr, with optional traceback. + """ + for line in message.splitlines(): + print("***", line, file=sys.stderr) + if exc_info: + print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr) + print("---") + + def print_error_explanation(message): lines = message.strip().split("\n") max_len = max([len(x) for x in lines]) diff --git a/modules/extensions.py b/modules/extensions.py index 624832a0..369d2584 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,11 +1,10 @@ import os -import sys import threading -import traceback import git from modules import shared +from modules.errors import print_error from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 extensions = [] @@ -56,8 +55,7 @@ class Extension: if os.path.exists(os.path.join(self.path, ".git")): repo = git.Repo(self.path) except Exception: - print(f"Error reading github repository info from {self.path}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error reading github repository info from {self.path}", exc_info=True) if repo is None or repo.bare: self.remote = None @@ -72,8 +70,8 @@ class Extension: self.commit_hash = commit.hexsha self.version = self.commit_hash[:8] - except Exception as ex: - print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr) + except Exception: + print_error(f"Failed reading extension data from Git repository ({self.name})", exc_info=True) self.remote = None self.have_info_from_repo = True diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 0131dea4..d2f647fe 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -1,12 +1,11 @@ import os -import sys -import traceback import facexlib import gfpgan import modules.face_restoration from modules import paths, shared, devices, modelloader +from modules.errors import print_error model_dir = "GFPGAN" user_path = None @@ -112,5 +111,4 @@ def setup_model(dirname): shared.face_restorers.append(FaceRestorerGFPGAN()) except Exception: - print("Error setting up GFPGAN:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error setting up GFPGAN", exc_info=True) diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 570b5603..fcc1ef20 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -2,8 +2,6 @@ import datetime import glob import html import os -import sys -import traceback import inspect import modules.textual_inversion.dataset @@ -12,6 +10,7 @@ import tqdm from einops import rearrange, repeat from ldm.util import default from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint +from modules.errors import print_error from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -325,17 +324,14 @@ def load_hypernetwork(name): if path is None: return None - hypernetwork = Hypernetwork() - try: + hypernetwork = Hypernetwork() hypernetwork.load(path) + return hypernetwork except Exception: - print(f"Error loading hypernetwork {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading hypernetwork {path}", exc_info=True) return None - return hypernetwork - def load_hypernetworks(names, multipliers=None): already_loaded = {} @@ -770,7 +766,7 @@ Last saved image: {html.escape(last_saved_image)}

""" except Exception: - print(traceback.format_exc(), file=sys.stderr) + print_error("Exception in training hypernetwork", exc_info=True) finally: pbar.leave = False pbar.close() diff --git a/modules/images.py b/modules/images.py index e21e554c..69151bec 100644 --- a/modules/images.py +++ b/modules/images.py @@ -1,6 +1,4 @@ import datetime -import sys -import traceback import pytz import io @@ -18,6 +16,7 @@ import json import hashlib from modules import sd_samplers, shared, script_callbacks, errors +from modules.errors import print_error from modules.paths_internal import roboto_ttf_file from modules.shared import opts @@ -464,8 +463,7 @@ class FilenameGenerator: replacement = fun(self, *pattern_args) except Exception: replacement = None - print(f"Error adding [{pattern}] to filename", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error adding [{pattern}] to filename", exc_info=True) if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT: continue @@ -697,8 +695,7 @@ def read_info_from_image(image): Negative prompt: {json_info["uc"]} Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337""" except Exception: - print("Error parsing NovelAI image generation parameters:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error parsing NovelAI image generation parameters", exc_info=True) return geninfo, items diff --git a/modules/interrogate.py b/modules/interrogate.py index 111b1322..d36e1a5a 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -1,6 +1,5 @@ import os import sys -import traceback from collections import namedtuple from pathlib import Path import re @@ -12,6 +11,7 @@ from torchvision import transforms from torchvision.transforms.functional import InterpolationMode from modules import devices, paths, shared, lowvram, modelloader, errors +from modules.errors import print_error blip_image_eval_size = 384 clip_model_name = 'ViT-L/14' @@ -216,8 +216,7 @@ class InterrogateModels: res += f", {match}" except Exception: - print("Error interrogating", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error interrogating", exc_info=True) res += "" self.unload() diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 35a52310..22edc106 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -8,6 +8,7 @@ import json from functools import lru_cache from modules import cmd_args +from modules.errors import print_error from modules.paths_internal import script_path, extensions_dir args, _ = cmd_args.parser.parse_known_args() @@ -188,7 +189,7 @@ def run_extension_installer(extension_dir): print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env)) except Exception as e: - print(e, file=sys.stderr) + print_error(str(e)) def list_extensions(settings_file): @@ -198,8 +199,8 @@ def list_extensions(settings_file): if os.path.isfile(settings_file): with open(settings_file, "r", encoding="utf8") as file: settings = json.load(file) - except Exception as e: - print(e, file=sys.stderr) + except Exception: + print_error("Could not load settings", exc_info=True) disabled_extensions = set(settings.get('disabled_extensions', [])) disable_all_extensions = settings.get('disable_all_extensions', 'none') diff --git a/modules/localization.py b/modules/localization.py index ee9c65e7..9a1df343 100644 --- a/modules/localization.py +++ b/modules/localization.py @@ -1,8 +1,7 @@ import json import os -import sys -import traceback +from modules.errors import print_error localizations = {} @@ -31,7 +30,6 @@ def localization_js(current_localization_name: str) -> str: with open(fn, "r", encoding="utf8") as file: data = json.load(file) except Exception: - print(f"Error loading localization from {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading localization from {fn}", exc_info=True) return f"window.localization = {json.dumps(data)}" diff --git a/modules/processing.py b/modules/processing.py index b75f2515..5c9bcce8 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1,4 +1,5 @@ import json +import logging import math import os import sys @@ -23,7 +24,6 @@ import modules.images as images import modules.styles import modules.sd_models as sd_models import modules.sd_vae as sd_vae -import logging from ldm.data.util import AddMiDaS from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 99983678..c8d0c64f 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,12 +1,11 @@ import os -import sys -import traceback import numpy as np from PIL import Image from basicsr.utils.download_util import load_file_from_url from realesrgan import RealESRGANer +from modules.errors import print_error from modules.upscaler import Upscaler, UpscalerData from modules.shared import cmd_opts, opts from modules import modelloader @@ -36,8 +35,7 @@ class UpscalerRealESRGAN(Upscaler): self.scalers.append(scaler) except Exception: - print("Error importing Real-ESRGAN:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error importing Real-ESRGAN", exc_info=True) self.enable = False self.scalers = [] @@ -76,9 +74,8 @@ class UpscalerRealESRGAN(Upscaler): info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True) return info - except Exception as e: - print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + except Exception: + print_error("Error making Real-ESRGAN models list", exc_info=True) return None def load_models(self, _): @@ -135,5 +132,4 @@ def get_realesrgan_models(scaler): ] return models except Exception: - print("Error making Real-ESRGAN models list:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error making Real-ESRGAN models list", exc_info=True) diff --git a/modules/safe.py b/modules/safe.py index e8f50774..b596f565 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -2,8 +2,6 @@ import pickle import collections -import sys -import traceback import torch import numpy @@ -11,6 +9,8 @@ import _codecs import zipfile import re +from modules.errors import print_error + # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage @@ -136,17 +136,20 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs): check_pt(filename, extra_handler) except pickle.UnpicklingError: - print(f"Error verifying pickled file from {filename}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr) - print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr) + print_error( + f"Error verifying pickled file from {filename}\n" + "-----> !!!! The file is most likely corrupted !!!! <-----\n" + "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", + exc_info=True, + ) return None - except Exception: - print(f"Error verifying pickled file from {filename}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) - print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr) + print_error( + f"Error verifying pickled file from {filename}\n" + f"The file may be malicious, so the program is not going to read it.\n" + f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", + exc_info=True, + ) return None return unsafe_torch_load(filename, *args, **kwargs) @@ -190,4 +193,3 @@ with safe.Extra(handler): unsafe_torch_load = torch.load torch.load = load global_extra_handler = None - diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index d2728e12..6aa9c3b6 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,16 +1,15 @@ -import sys -import traceback -from collections import namedtuple import inspect +from collections import namedtuple from typing import Optional, Dict, Any from fastapi import FastAPI from gradio import Blocks +from modules.errors import print_error + def report_exception(c, job): - print(f"Error executing callback {job} for {c.script}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error executing callback {job} for {c.script}", exc_info=True) class ImageSaveParams: diff --git a/modules/script_loading.py b/modules/script_loading.py index 57b15862..26efffcb 100644 --- a/modules/script_loading.py +++ b/modules/script_loading.py @@ -1,8 +1,8 @@ import os -import sys -import traceback import importlib.util +from modules.errors import print_error + def load_module(path): module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path) @@ -27,5 +27,4 @@ def preload_extensions(extensions_dir, parser): module.preload(parser) except Exception: - print(f"Error running preload() for {preload_script}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running preload() for {preload_script}", exc_info=True) diff --git a/modules/scripts.py b/modules/scripts.py index c902804b..a7168fd1 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,12 +1,12 @@ import os import re import sys -import traceback from collections import namedtuple import gradio as gr from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing +from modules.errors import print_error AlwaysVisible = object() @@ -264,8 +264,7 @@ def load_scripts(): register_scripts_from_module(script_module) except Exception: - print(f"Error loading script: {scriptfile.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading script: {scriptfile.filename}", exc_info=True) finally: sys.path = syspath @@ -280,11 +279,9 @@ def load_scripts(): def wrap_call(func, filename, funcname, *args, default=None, **kwargs): try: - res = func(*args, **kwargs) - return res + return func(*args, **kwargs) except Exception: - print(f"Error calling: {filename}/{funcname}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error calling: {filename}/{funcname}", exc_info=True) return default @@ -450,8 +447,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.process(p, *script_args) except Exception: - print(f"Error running process: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running process: {script.filename}", exc_info=True) def before_process_batch(self, p, **kwargs): for script in self.alwayson_scripts: @@ -459,8 +455,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.before_process_batch(p, *script_args, **kwargs) except Exception: - print(f"Error running before_process_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running before_process_batch: {script.filename}", exc_info=True) def process_batch(self, p, **kwargs): for script in self.alwayson_scripts: @@ -468,8 +463,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.process_batch(p, *script_args, **kwargs) except Exception: - print(f"Error running process_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running process_batch: {script.filename}", exc_info=True) def postprocess(self, p, processed): for script in self.alwayson_scripts: @@ -477,8 +471,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess(p, processed, *script_args) except Exception: - print(f"Error running postprocess: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running postprocess: {script.filename}", exc_info=True) def postprocess_batch(self, p, images, **kwargs): for script in self.alwayson_scripts: @@ -486,8 +479,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_batch(p, *script_args, images=images, **kwargs) except Exception: - print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running postprocess_batch: {script.filename}", exc_info=True) def postprocess_image(self, p, pp: PostprocessImageArgs): for script in self.alwayson_scripts: @@ -495,24 +487,21 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_image(p, pp, *script_args) except Exception: - print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running postprocess_image: {script.filename}", exc_info=True) def before_component(self, component, **kwargs): for script in self.scripts: try: script.before_component(component, **kwargs) except Exception: - print(f"Error running before_component: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running before_component: {script.filename}", exc_info=True) def after_component(self, component, **kwargs): for script in self.scripts: try: script.after_component(component, **kwargs) except Exception: - print(f"Error running after_component: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error running after_component: {script.filename}", exc_info=True) def reload_sources(self, cache): for si, script in list(enumerate(self.scripts)): diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 2ec0b049..fd186fa2 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,7 +1,5 @@ from __future__ import annotations import math -import sys -import traceback import psutil import torch @@ -11,6 +9,7 @@ from ldm.util import default from einops import rearrange from modules import shared, errors, devices, sub_quadratic_attention +from modules.errors import print_error from modules.hypernetworks import hypernetwork import ldm.modules.attention @@ -140,8 +139,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: import xformers.ops shared.xformers_available = True except Exception: - print("Cannot import xformers", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Cannot import xformers", exc_info=True) def get_available_vram(): diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index d489ed1e..a040a988 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -1,6 +1,4 @@ import os -import sys -import traceback from collections import namedtuple import torch @@ -16,6 +14,7 @@ from torch.utils.tensorboard import SummaryWriter from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint import modules.textual_inversion.dataset +from modules.errors import print_error from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay @@ -207,8 +206,7 @@ class EmbeddingDatabase: self.load_from_file(fullfn, fn) except Exception: - print(f"Error loading embedding {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error loading embedding {fn}", exc_info=True) continue def load_textual_inversion_embeddings(self, force_reload=False): @@ -632,8 +630,7 @@ Last saved image: {html.escape(last_saved_image)}
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True) except Exception: - print(traceback.format_exc(), file=sys.stderr) - pass + print_error("Error training embedding", exc_info=True) finally: pbar.leave = False pbar.close() diff --git a/modules/ui.py b/modules/ui.py index 001b9792..1ad94f02 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -2,7 +2,6 @@ import json import mimetypes import os import sys -import traceback from functools import reduce import warnings @@ -14,6 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave +from modules.errors import print_error from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path, data_path @@ -231,9 +231,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: res = all_seeds[index if 0 <= index < len(all_seeds) else 0] except json.decoder.JSONDecodeError: - if gen_info_string != '': - print("Error parsing JSON generation info:", file=sys.stderr) - print(gen_info_string, file=sys.stderr) + if gen_info_string: + print_error(f"Error parsing JSON generation info: {gen_info_string}") return [res, gr_show(False)] @@ -1753,8 +1752,7 @@ def create_ui(): try: results = modules.extras.run_modelmerger(*args) except Exception as e: - print("Error loading/saving model file:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error("Error loading/saving model file", exc_info=True) modules.sd_models.list_models() # to remove the potentially missing models from the list return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"] return results diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 515ec262..cadf56be 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -1,10 +1,8 @@ import json import os.path -import sys import threading import time from datetime import datetime -import traceback import git @@ -14,6 +12,7 @@ import shutil import errno from modules import extensions, shared, paths, config_states +from modules.errors import print_error from modules.paths_internal import config_states_dir from modules.call_queue import wrap_gradio_gpu_call @@ -46,8 +45,7 @@ def apply_and_restart(disable_list, update_list, disable_all): try: ext.fetch_and_reset_hard() except Exception: - print(f"Error getting updates for {ext.name}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error getting updates for {ext.name}", exc_info=True) shared.opts.disabled_extensions = disabled shared.opts.disable_all_extensions = disable_all @@ -113,8 +111,7 @@ def check_updates(id_task, disable_list): if 'FETCH_HEAD' not in str(e): raise except Exception: - print(f"Error checking updates for {ext.name}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error checking updates for {ext.name}", exc_info=True) shared.state.nextjob() diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index b918a764..4dc24615 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -1,13 +1,12 @@ import copy import random -import sys -import traceback import shlex import modules.scripts as scripts import gradio as gr from modules import sd_samplers +from modules.errors import print_error from modules.processing import Processed, process_images from modules.shared import state @@ -136,8 +135,7 @@ class Script(scripts.Script): try: args = cmdargs(line) except Exception: - print(f"Error parsing line {line} as commandline:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + print_error(f"Error parsing line {line} as commandline", exc_info=True) args = {"prompt": line} else: args = {"prompt": line} -- cgit v1.2.3 From 4a449375a20267035402eed5d93074c2f0a91bc8 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 30 May 2023 01:07:35 +0900 Subject: fix get_conds_with_caching() --- modules/processing.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index b75f2515..395c851f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -321,14 +321,13 @@ class StableDiffusionProcessing: have been used before. The second element is where the previously computed result is stored. """ - - if cache[0] is not None and (required_prompts, steps) == cache[0]: + if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) == cache[0]: return cache[1] with devices.autocast(): cache[1] = function(shared.sd_model, required_prompts, steps) - cache[0] = (required_prompts, steps) + cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) return cache[1] def setup_conds(self): -- cgit v1.2.3 From df02498d03e4296b7d7581aff69571a49be1d27a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 31 May 2023 22:40:09 +0300 Subject: add an option to show selected setting in main txt2img/img2img UI split some code from ui.py into ui_settings.py ui_gradio_edxtensions.py add before_process callback for scripts add ability for alwayson scripts to specify section and let user reorder those sections --- .../scripts/extra_options_section.py | 48 +++ modules/processing.py | 6 +- modules/scripts.py | 116 ++++--- modules/shared_items.py | 10 + modules/ui.py | 351 ++------------------- modules/ui_common.py | 23 ++ modules/ui_gradio_extensions.py | 69 ++++ modules/ui_settings.py | 263 +++++++++++++++ 8 files changed, 512 insertions(+), 374 deletions(-) create mode 100644 extensions-builtin/extra-options-section/scripts/extra_options_section.py create mode 100644 modules/ui_gradio_extensions.py create mode 100644 modules/ui_settings.py (limited to 'modules/processing.py') diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py new file mode 100644 index 00000000..17f84184 --- /dev/null +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -0,0 +1,48 @@ +import gradio as gr +from modules import scripts, shared, ui_components, ui_settings +from modules.ui_components import FormColumn + + +class ExtraOptionsSection(scripts.Script): + section = "extra_options" + + def __init__(self): + self.comps = None + self.setting_names = None + + def title(self): + return "Extra options" + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_img2img): + self.comps = [] + self.setting_names = [] + + with gr.Blocks() as interface: + with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and len(shared.opts.extra_options) > 0 else gr.Group(), gr.Row(): + for setting_name in shared.opts.extra_options: + with FormColumn(): + comp = ui_settings.create_setting_component(setting_name) + + self.comps.append(comp) + self.setting_names.append(setting_name) + + def get_settings_values(): + return [ui_settings.get_value_for_setting(key) for key in self.setting_names] + + interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False) + + return self.comps + + def before_process(self, p, *args): + for name, value in zip(self.setting_names, args): + if name not in p.override_settings: + p.override_settings[name] = value + + +shared.options_templates.update(shared.options_section(('ui', "User interface"), { + "extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(), + "extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion") +})) diff --git a/modules/processing.py b/modules/processing.py index f628d88b..baa9b278 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -588,11 +588,15 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter def process_images(p: StableDiffusionProcessing) -> Processed: + if p.scripts is not None: + p.scripts.before_process(p) + stored_opts = {k: opts.data[k] for k in p.override_settings.keys()} try: # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint - if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None: + override_checkpoint = p.override_settings.get('sd_model_checkpoint') + if override_checkpoint is not None and sd_models.checkpoint_alisases.get(override_checkpoint) is None: p.override_settings.pop('sd_model_checkpoint', None) sd_models.reload_model_weights() diff --git a/modules/scripts.py b/modules/scripts.py index 0970f38e..b901862d 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -19,6 +19,9 @@ class Script: name = None """script's internal name derived from title""" + section = None + """name of UI section that the script's controls will be placed into""" + filename = None args_from = None args_to = None @@ -81,6 +84,15 @@ class Script: pass + def before_process(self, p, *args): + """ + This function is called very early before processing begins for AlwaysVisible scripts. + You can modify the processing object (p) here, inject hooks, etc. + args contains all values returned by components from ui() + """ + + pass + def process(self, p, *args): """ This function is called before processing begins for AlwaysVisible scripts. @@ -293,6 +305,7 @@ class ScriptRunner: self.titles = [] self.infotext_fields = [] self.paste_field_names = [] + self.inputs = [None] def initialize_scripts(self, is_img2img): from modules import scripts_auto_postprocessing @@ -320,69 +333,73 @@ class ScriptRunner: self.scripts.append(script) self.selectable_scripts.append(script) - def setup_ui(self): + def create_script_ui(self, script): import modules.api.models as api_models - self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts] + script.args_from = len(self.inputs) + script.args_to = len(self.inputs) - inputs = [None] - inputs_alwayson = [True] + controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img) - def create_script_ui(script, inputs, inputs_alwayson): - script.args_from = len(inputs) - script.args_to = len(inputs) + if controls is None: + return - controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img) + script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower() + api_args = [] - if controls is None: - return + for control in controls: + control.custom_script_source = os.path.basename(script.filename) - script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower() - api_args = [] + arg_info = api_models.ScriptArg(label=control.label or "") - for control in controls: - control.custom_script_source = os.path.basename(script.filename) + for field in ("value", "minimum", "maximum", "step", "choices"): + v = getattr(control, field, None) + if v is not None: + setattr(arg_info, field, v) - arg_info = api_models.ScriptArg(label=control.label or "") + api_args.append(arg_info) - for field in ("value", "minimum", "maximum", "step", "choices"): - v = getattr(control, field, None) - if v is not None: - setattr(arg_info, field, v) + script.api_info = api_models.ScriptInfo( + name=script.name, + is_img2img=script.is_img2img, + is_alwayson=script.alwayson, + args=api_args, + ) - api_args.append(arg_info) + if script.infotext_fields is not None: + self.infotext_fields += script.infotext_fields - script.api_info = api_models.ScriptInfo( - name=script.name, - is_img2img=script.is_img2img, - is_alwayson=script.alwayson, - args=api_args, - ) + if script.paste_field_names is not None: + self.paste_field_names += script.paste_field_names - if script.infotext_fields is not None: - self.infotext_fields += script.infotext_fields + self.inputs += controls + script.args_to = len(self.inputs) - if script.paste_field_names is not None: - self.paste_field_names += script.paste_field_names + def setup_ui_for_section(self, section, scriptlist=None): + if scriptlist is None: + scriptlist = self.alwayson_scripts - inputs += controls - inputs_alwayson += [script.alwayson for _ in controls] - script.args_to = len(inputs) + for script in scriptlist: + if script.alwayson and script.section != section: + continue - for script in self.alwayson_scripts: - with gr.Group() as group: - create_script_ui(script, inputs, inputs_alwayson) + with gr.Group(visible=script.alwayson) as group: + self.create_script_ui(script) script.group = group - dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index") - inputs[0] = dropdown + def prepare_ui(self): + self.inputs = [None] - for script in self.selectable_scripts: - with gr.Group(visible=False) as group: - create_script_ui(script, inputs, inputs_alwayson) + def setup_ui(self): + self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts] - script.group = group + self.setup_ui_for_section(None) + + dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index") + self.inputs[0] = dropdown + + self.setup_ui_for_section(None, self.selectable_scripts) def select_script(script_index): selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None @@ -407,6 +424,7 @@ class ScriptRunner: ) self.script_load_ctr = 0 + def onload_script_visibility(params): title = params.get('Script', None) if title: @@ -417,10 +435,10 @@ class ScriptRunner: else: return gr.update(visible=False) - self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) ) - self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] ) + self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None')))) + self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts]) - return inputs + return self.inputs def run(self, p, *args): script_index = args[0] @@ -440,6 +458,14 @@ class ScriptRunner: return processed + def before_process(self, p): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.before_process(p, *script_args) + except Exception: + errors.report(f"Error running before_process: {script.filename}", exc_info=True) + def process(self, p): for script in self.alwayson_scripts: try: diff --git a/modules/shared_items.py b/modules/shared_items.py index 27bceb18..89792e88 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -55,5 +55,15 @@ ui_reorder_categories_builtin_items = [ def ui_reorder_categories(): + from modules import scripts + yield from ui_reorder_categories_builtin_items + + sections = {} + for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts: + if isinstance(script.section, str): + sections[script.section] = 1 + + yield from sections + yield "scripts" diff --git a/modules/ui.py b/modules/ui.py index 35563669..4e0cf776 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -6,15 +6,17 @@ from functools import reduce import warnings import gradio as gr -import gradio.routes import gradio.utils import numpy as np from PIL import Image, PngImagePlugin # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items +from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML -from modules.paths import script_path, data_path +from modules.paths import script_path +from modules.ui_common import create_refresh_button +from modules.ui_gradio_extensions import reload_javascript + from modules.shared import opts, cmd_opts @@ -34,6 +36,8 @@ import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text import modules.extras +create_setting_component = ui_settings.create_setting_component + warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI @@ -366,25 +370,6 @@ def apply_setting(key, value): return getattr(opts, key) -def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): - def refresh(): - refresh_method() - args = refreshed_args() if callable(refreshed_args) else refreshed_args - - for k, v in args.items(): - setattr(refresh_component, k, v) - - return gr.update(**(args or {})) - - refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) - refresh_button.click( - fn=refresh, - inputs=[], - outputs=[refresh_component] - ) - return refresh_button - - def create_output_panel(tabname, outdir): return ui_common.create_output_panel(tabname, outdir) @@ -409,16 +394,6 @@ def ordered_ui_categories(): yield category -def get_value_for_setting(key): - value = getattr(opts, key) - - info = opts.data_labels[key] - args = info.component_args() if callable(info.component_args) else info.component_args or {} - args = {k: v for k, v in args.items() if k not in {'precision'}} - - return gr.update(value=value, **args) - - def create_override_settings_dropdown(tabname, row): dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True) @@ -454,6 +429,8 @@ def create_ui(): with gr.Row().style(equal_height=False): with gr.Column(variant='compact', elem_id="txt2img_settings"): + modules.scripts.scripts_txt2img.prepare_ui() + for category in ordered_ui_categories(): if category == "sampler": steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") @@ -522,6 +499,9 @@ def create_ui(): with FormGroup(elem_id="txt2img_script_container"): custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + else: + modules.scripts.scripts_txt2img.setup_ui_for_section(category) + hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] for component in hr_resolution_preview_inputs: @@ -778,6 +758,8 @@ def create_ui(): with FormRow(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + modules.scripts.scripts_img2img.prepare_ui() + for category in ordered_ui_categories(): if category == "sampler": steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") @@ -887,6 +869,8 @@ def create_ui(): inputs=[], outputs=[inpaint_controls, mask_alpha], ) + else: + modules.scripts.scripts_img2img.setup_ui_for_section(category) img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) @@ -1460,195 +1444,10 @@ def create_ui(): outputs=[], ) - def create_setting_component(key, is_quicksettings=False): - def fun(): - return opts.data[key] if key in opts.data else opts.data_labels[key].default - - info = opts.data_labels[key] - t = type(info.default) - - args = info.component_args() if callable(info.component_args) else info.component_args - - if info.component is not None: - comp = info.component - elif t == str: - comp = gr.Textbox - elif t == int: - comp = gr.Number - elif t == bool: - comp = gr.Checkbox - else: - raise Exception(f'bad options item type: {t} for key {key}') - - elem_id = f"setting_{key}" - - if info.refresh is not None: - if is_quicksettings: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}") - else: - with FormRow(): - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}") - else: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - - return res - loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file) - components = [] - component_dict = {} - shared.settings_components = component_dict - - script_callbacks.ui_settings_callback() - opts.reorder() - - def run_settings(*args): - changed = [] - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - if comp == dummy_component: - continue - - if opts.set(key, value): - changed.append(key) - - try: - opts.save(shared.config_filename) - except RuntimeError: - return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' - return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' - - def run_settings_single(value, key): - if not opts.same_type(value, opts.data_labels[key].default): - return gr.update(visible=True), opts.dumpjson() - - if not opts.set(key, value): - return gr.update(value=getattr(opts, key)), opts.dumpjson() - - opts.save(shared.config_filename) - - return get_value_for_setting(key), opts.dumpjson() - - with gr.Blocks(analytics_enabled=False) as settings_interface: - with gr.Row(): - with gr.Column(scale=6): - settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") - with gr.Column(): - restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") - - result = gr.HTML(elem_id="settings_result") - - quicksettings_names = opts.quicksettings_list - quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} - - quicksettings_list = [] - - previous_section = None - current_tab = None - current_row = None - with gr.Tabs(elem_id="settings"): - for i, (k, item) in enumerate(opts.data_labels.items()): - section_must_be_skipped = item.section[0] is None - - if previous_section != item.section and not section_must_be_skipped: - elem_id, text = item.section - - if current_tab is not None: - current_row.__exit__() - current_tab.__exit__() - - gr.Group() - current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text) - current_tab.__enter__() - current_row = gr.Column(variant='compact') - current_row.__enter__() - - previous_section = item.section - - if k in quicksettings_names and not shared.cmd_opts.freeze_settings: - quicksettings_list.append((i, k, item)) - components.append(dummy_component) - elif section_must_be_skipped: - components.append(dummy_component) - else: - component = create_setting_component(k) - component_dict[k] = component - components.append(component) - - if current_tab is not None: - current_row.__exit__() - current_tab.__exit__() - - with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"): - loadsave.create_ui() - - with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"): - request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") - download_localization = gr.Button(value='Download localization template', elem_id="download_localization") - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") - with gr.Row(): - unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model") - reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model") - - with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"): - gr.HTML(shared.html("licenses.html"), elem_id="licenses") - - gr.Button(value="Show all pages", elem_id="settings_show_all_pages") - - - def unload_sd_weights(): - modules.sd_models.unload_model_weights() - - def reload_sd_weights(): - modules.sd_models.reload_model_weights() - - unload_sd_model.click( - fn=unload_sd_weights, - inputs=[], - outputs=[] - ) - - reload_sd_model.click( - fn=reload_sd_weights, - inputs=[], - outputs=[] - ) - - request_notifications.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='function(){}' - ) - - download_localization.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='download_localization' - ) - - def reload_scripts(): - modules.scripts.reload_script_body_only() - reload_javascript() # need to refresh the html page - - reload_script_bodies.click( - fn=reload_scripts, - inputs=[], - outputs=[] - ) - - restart_gradio.click( - fn=shared.state.request_restart, - _js='restart_reload', - inputs=[], - outputs=[], - ) + settings = ui_settings.UiSettings() + settings.create_ui(loadsave, dummy_component) interfaces = [ (txt2img_interface, "txt2img", "txt2img"), @@ -1660,7 +1459,7 @@ def create_ui(): ] interfaces += script_callbacks.ui_tabs_callback() - interfaces += [(settings_interface, "Settings", "settings")] + interfaces += [(settings.interface, "Settings", "settings")] extensions_interface = ui_extensions.create_ui() interfaces += [(extensions_interface, "Extensions", "extensions")] @@ -1670,10 +1469,7 @@ def create_ui(): shared.tab_names.append(label) with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo: - with gr.Row(elem_id="quicksettings", variant="compact"): - for _i, k, _item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): - component = create_setting_component(k, is_quicksettings=True) - component_dict[k] = component + settings.add_quicksettings() parameters_copypaste.connect_paste_params_buttons() @@ -1704,49 +1500,12 @@ def create_ui(): footer = footer.format(versions=versions_html()) gr.HTML(footer, elem_id="footer") - text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) - settings_submit.click( - fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), - inputs=components, - outputs=[text_settings, result], - ) - - for _i, k, _item in quicksettings_list: - component = component_dict[k] - info = opts.data_labels[k] - - change_handler = component.release if hasattr(component, 'release') else component.change - change_handler( - fn=lambda value, k=k: run_settings_single(value, key=k), - inputs=[component], - outputs=[component, text_settings], - show_progress=info.refresh is not None, - ) + settings.add_functionality(demo) update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit") - text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) + settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) - button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) - button_set_checkpoint.click( - fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'), - _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }", - inputs=[component_dict['sd_model_checkpoint'], dummy_component], - outputs=[component_dict['sd_model_checkpoint'], text_settings], - ) - - component_keys = [k for k in opts.data_labels.keys() if k in component_dict] - - def get_settings_values(): - return [get_value_for_setting(key) for key in component_keys] - - demo.load( - fn=get_settings_values, - inputs=[], - outputs=[component_dict[k] for k in component_keys], - queue=False, - ) - def modelmerger(*args): try: results = modules.extras.run_modelmerger(*args) @@ -1779,7 +1538,7 @@ def create_ui(): primary_model_name, secondary_model_name, tertiary_model_name, - component_dict['sd_model_checkpoint'], + settings.component_dict['sd_model_checkpoint'], modelmerger_result, ] ) @@ -1793,70 +1552,6 @@ def create_ui(): return demo -def webpath(fn): - if fn.startswith(script_path): - web_path = os.path.relpath(fn, script_path).replace('\\', '/') - else: - web_path = os.path.abspath(fn) - - return f'file={web_path}?{os.path.getmtime(fn)}' - - -def javascript_html(): - # Ensure localization is in `window` before scripts - head = f'\n' - - script_js = os.path.join(script_path, "script.js") - head += f'\n' - - for script in modules.scripts.list_scripts("javascript", ".js"): - head += f'\n' - - for script in modules.scripts.list_scripts("javascript", ".mjs"): - head += f'\n' - - if cmd_opts.theme: - head += f'\n' - - return head - - -def css_html(): - head = "" - - def stylesheet(fn): - return f'' - - for cssfile in modules.scripts.list_files_with_name("style.css"): - if not os.path.isfile(cssfile): - continue - - head += stylesheet(cssfile) - - if os.path.exists(os.path.join(data_path, "user.css")): - head += stylesheet(os.path.join(data_path, "user.css")) - - return head - - -def reload_javascript(): - js = javascript_html() - css = css_html() - - def template_response(*args, **kwargs): - res = shared.GradioTemplateResponseOriginal(*args, **kwargs) - res.body = res.body.replace(b'', f'{js}'.encode("utf8")) - res.body = res.body.replace(b'', f'{css}'.encode("utf8")) - res.init_headers() - return res - - gradio.routes.templates.TemplateResponse = template_response - - -if not hasattr(shared, 'GradioTemplateResponseOriginal'): - shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse - - def versions_html(): import torch import launch diff --git a/modules/ui_common.py b/modules/ui_common.py index 5a9204a4..57c2d0ad 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -10,8 +10,11 @@ import subprocess as sp from modules import call_queue, shared from modules.generation_parameters_copypaste import image_from_url_text import modules.images +from modules.ui_components import ToolButton + folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 def update_generation_info(generation_info, html_info, img_index): @@ -216,3 +219,23 @@ Requested path was: {f} )) return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log + + +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + + refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + return refresh_button + diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py new file mode 100644 index 00000000..b824b113 --- /dev/null +++ b/modules/ui_gradio_extensions.py @@ -0,0 +1,69 @@ +import os +import gradio as gr + +from modules import localization, shared, scripts +from modules.paths import script_path, data_path + + +def webpath(fn): + if fn.startswith(script_path): + web_path = os.path.relpath(fn, script_path).replace('\\', '/') + else: + web_path = os.path.abspath(fn) + + return f'file={web_path}?{os.path.getmtime(fn)}' + + +def javascript_html(): + # Ensure localization is in `window` before scripts + head = f'\n' + + script_js = os.path.join(script_path, "script.js") + head += f'\n' + + for script in scripts.list_scripts("javascript", ".js"): + head += f'\n' + + for script in scripts.list_scripts("javascript", ".mjs"): + head += f'\n' + + if shared.cmd_opts.theme: + head += f'\n' + + return head + + +def css_html(): + head = "" + + def stylesheet(fn): + return f'' + + for cssfile in scripts.list_files_with_name("style.css"): + if not os.path.isfile(cssfile): + continue + + head += stylesheet(cssfile) + + if os.path.exists(os.path.join(data_path, "user.css")): + head += stylesheet(os.path.join(data_path, "user.css")) + + return head + + +def reload_javascript(): + js = javascript_html() + css = css_html() + + def template_response(*args, **kwargs): + res = shared.GradioTemplateResponseOriginal(*args, **kwargs) + res.body = res.body.replace(b'', f'{js}'.encode("utf8")) + res.body = res.body.replace(b'', f'{css}'.encode("utf8")) + res.init_headers() + return res + + gr.routes.templates.TemplateResponse = template_response + + +if not hasattr(shared, 'GradioTemplateResponseOriginal'): + shared.GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse diff --git a/modules/ui_settings.py b/modules/ui_settings.py new file mode 100644 index 00000000..7874298e --- /dev/null +++ b/modules/ui_settings.py @@ -0,0 +1,263 @@ +import gradio as gr + +from modules import ui_common, shared, script_callbacks, scripts, sd_models +from modules.call_queue import wrap_gradio_call +from modules.shared import opts +from modules.ui_components import FormRow +from modules.ui_gradio_extensions import reload_javascript + + +def get_value_for_setting(key): + value = getattr(opts, key) + + info = opts.data_labels[key] + args = info.component_args() if callable(info.component_args) else info.component_args or {} + args = {k: v for k, v in args.items() if k not in {'precision'}} + + return gr.update(value=value, **args) + + +def create_setting_component(key, is_quicksettings=False): + def fun(): + return opts.data[key] if key in opts.data else opts.data_labels[key].default + + info = opts.data_labels[key] + t = type(info.default) + + args = info.component_args() if callable(info.component_args) else info.component_args + + if info.component is not None: + comp = info.component + elif t == str: + comp = gr.Textbox + elif t == int: + comp = gr.Number + elif t == bool: + comp = gr.Checkbox + else: + raise Exception(f'bad options item type: {t} for key {key}') + + elem_id = f"setting_{key}" + + if info.refresh is not None: + if is_quicksettings: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}") + else: + with FormRow(): + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}") + else: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + + return res + + +class UiSettings: + submit = None + result = None + interface = None + components = None + component_dict = None + dummy_component = None + quicksettings_list = None + quicksettings_names = None + text_settings = None + + def run_settings(self, *args): + changed = [] + + for key, value, comp in zip(opts.data_labels.keys(), args, self.components): + assert comp == self.dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" + + for key, value, comp in zip(opts.data_labels.keys(), args, self.components): + if comp == self.dummy_component: + continue + + if opts.set(key, value): + changed.append(key) + + try: + opts.save(shared.config_filename) + except RuntimeError: + return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' + + def run_settings_single(self, value, key): + if not opts.same_type(value, opts.data_labels[key].default): + return gr.update(visible=True), opts.dumpjson() + + if not opts.set(key, value): + return gr.update(value=getattr(opts, key)), opts.dumpjson() + + opts.save(shared.config_filename) + + return get_value_for_setting(key), opts.dumpjson() + + def create_ui(self, loadsave, dummy_component): + self.components = [] + self.component_dict = {} + self.dummy_component = dummy_component + + shared.settings_components = self.component_dict + + script_callbacks.ui_settings_callback() + opts.reorder() + + with gr.Blocks(analytics_enabled=False) as settings_interface: + with gr.Row(): + with gr.Column(scale=6): + self.submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") + with gr.Column(): + restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") + + self.result = gr.HTML(elem_id="settings_result") + + self.quicksettings_names = opts.quicksettings_list + self.quicksettings_names = {x: i for i, x in enumerate(self.quicksettings_names) if x != 'quicksettings'} + + self.quicksettings_list = [] + + previous_section = None + current_tab = None + current_row = None + with gr.Tabs(elem_id="settings"): + for i, (k, item) in enumerate(opts.data_labels.items()): + section_must_be_skipped = item.section[0] is None + + if previous_section != item.section and not section_must_be_skipped: + elem_id, text = item.section + + if current_tab is not None: + current_row.__exit__() + current_tab.__exit__() + + gr.Group() + current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text) + current_tab.__enter__() + current_row = gr.Column(variant='compact') + current_row.__enter__() + + previous_section = item.section + + if k in self.quicksettings_names and not shared.cmd_opts.freeze_settings: + self.quicksettings_list.append((i, k, item)) + self.components.append(dummy_component) + elif section_must_be_skipped: + self.components.append(dummy_component) + else: + component = create_setting_component(k) + self.component_dict[k] = component + self.components.append(component) + + if current_tab is not None: + current_row.__exit__() + current_tab.__exit__() + + with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"): + loadsave.create_ui() + + with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"): + request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") + with gr.Row(): + unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model") + reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model") + + with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"): + gr.HTML(shared.html("licenses.html"), elem_id="licenses") + + gr.Button(value="Show all pages", elem_id="settings_show_all_pages") + + self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) + + unload_sd_model.click( + fn=sd_models.unload_model_weights, + inputs=[], + outputs=[] + ) + + reload_sd_model.click( + fn=sd_models.reload_model_weights, + inputs=[], + outputs=[] + ) + + request_notifications.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='function(){}' + ) + + download_localization.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='download_localization' + ) + + def reload_scripts(): + scripts.reload_script_body_only() + reload_javascript() # need to refresh the html page + + reload_script_bodies.click( + fn=reload_scripts, + inputs=[], + outputs=[] + ) + + restart_gradio.click( + fn=shared.state.request_restart, + _js='restart_reload', + inputs=[], + outputs=[], + ) + + self.interface = settings_interface + + def add_quicksettings(self): + with gr.Row(elem_id="quicksettings", variant="compact"): + for _i, k, _item in sorted(self.quicksettings_list, key=lambda x: self.quicksettings_names.get(x[1], x[0])): + component = create_setting_component(k, is_quicksettings=True) + self.component_dict[k] = component + + def add_functionality(self, demo): + self.submit.click( + fn=wrap_gradio_call(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]), + inputs=self.components, + outputs=[self.text_settings, self.result], + ) + + for _i, k, _item in self.quicksettings_list: + component = self.component_dict[k] + info = opts.data_labels[k] + + change_handler = component.release if hasattr(component, 'release') else component.change + change_handler( + fn=lambda value, k=k: self.run_settings_single(value, key=k), + inputs=[component], + outputs=[component, self.text_settings], + show_progress=info.refresh is not None, + ) + + button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) + button_set_checkpoint.click( + fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'), + _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }", + inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component], + outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings], + ) + + component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict] + + def get_settings_values(): + return [get_value_for_setting(key) for key in component_keys] + + demo.load( + fn=get_settings_values, + inputs=[], + outputs=[self.component_dict[k] for k in component_keys], + queue=False, + ) -- cgit v1.2.3 From a9674359ca4cd5ed8063a4dd23034450d80ba097 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 1 Jun 2023 19:52:04 +0300 Subject: revert the erroneous change for model setting added in df02498d --- modules/processing.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index baa9b278..362ab4c2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -595,8 +595,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try: # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint - override_checkpoint = p.override_settings.get('sd_model_checkpoint') - if override_checkpoint is not None and sd_models.checkpoint_alisases.get(override_checkpoint) is None: + if sd_models.checkpoint_alisases.get(p.override_settings.get('sd_model_checkpoint')) is None: p.override_settings.pop('sd_model_checkpoint', None) sd_models.reload_model_weights() -- cgit v1.2.3 From 51864790fd72386fbbbb015d24a43ce501ecaa4b Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 2 Jun 2023 14:58:10 +0300 Subject: Simplify a bunch of `len(x) > 0`/`len(x) == 0` style expressions --- extensions-builtin/LDSR/sd_hijack_autoencoder.py | 3 ++- extensions-builtin/LDSR/sd_hijack_ddpm_v1.py | 4 ++-- extensions-builtin/Lora/extra_networks_lora.py | 4 ++-- extensions-builtin/Lora/lora.py | 4 ++-- .../extra-options-section/scripts/extra_options_section.py | 2 +- modules/api/api.py | 2 +- modules/call_queue.py | 2 +- modules/extra_networks_hypernet.py | 4 ++-- modules/generation_parameters_copypaste.py | 6 ++---- modules/images.py | 6 +++--- modules/img2img.py | 3 +-- modules/models/diffusion/ddpm_edit.py | 4 ++-- modules/processing.py | 3 ++- modules/prompt_parser.py | 6 +++--- modules/script_callbacks.py | 4 ++-- modules/sd_hijack_clip.py | 2 +- modules/sd_hijack_clip_old.py | 2 +- modules/textual_inversion/autocrop.py | 14 +++++++------- modules/textual_inversion/dataset.py | 2 +- modules/textual_inversion/preprocess.py | 4 ++-- modules/textual_inversion/textual_inversion.py | 2 +- modules/ui.py | 2 +- modules/ui_extensions.py | 5 +++-- modules/ui_settings.py | 2 +- scripts/prompts_from_file.py | 3 +-- 25 files changed, 47 insertions(+), 48 deletions(-) (limited to 'modules/processing.py') diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py index 27a86e13..c29d274d 100644 --- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py +++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py @@ -91,8 +91,9 @@ class VQModel(pl.LightningModule): del sd[k] missing, unexpected = self.load_state_dict(sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: + if missing: print(f"Missing Keys: {missing}") + if unexpected: print(f"Unexpected Keys: {unexpected}") def on_train_batch_end(self, *args, **kwargs): diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py index 631a08ef..04adc5eb 100644 --- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py +++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py @@ -195,9 +195,9 @@ class DDPMV1(pl.LightningModule): missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: + if missing: print(f"Missing Keys: {missing}") - if len(unexpected) > 0: + if unexpected: print(f"Unexpected Keys: {unexpected}") def q_mean_variance(self, x_start, t): diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index b5fea4d2..66ee9c85 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -9,14 +9,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): def activate(self, p, params_list): additional = shared.opts.sd_lora - if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0: + if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional): p.all_prompts = [x + f"" for x in p.all_prompts] params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) names = [] multipliers = [] for params in params_list: - assert len(params.items) > 0 + assert params.items names.append(params.items[0]) multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index eec14712..af93991c 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -219,7 +219,7 @@ def load_lora(name, lora_on_disk): else: raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha") - if len(keys_failed_to_match) > 0: + if keys_failed_to_match: print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}") return lora @@ -267,7 +267,7 @@ def load_loras(names, multipliers=None): lora.multiplier = multipliers[i] if multipliers else 1.0 loaded_loras.append(lora) - if len(failed_to_load_loras) > 0: + if failed_to_load_loras: sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras)) diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index 17f84184..a05e10d8 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -21,7 +21,7 @@ class ExtraOptionsSection(scripts.Script): self.setting_names = [] with gr.Blocks() as interface: - with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and len(shared.opts.extra_options) > 0 else gr.Group(), gr.Row(): + with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row(): for setting_name in shared.opts.extra_options: with FormColumn(): comp = ui_settings.create_setting_component(setting_name) diff --git a/modules/api/api.py b/modules/api/api.py index d34ab422..555eefdb 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -280,7 +280,7 @@ class Api: script_args[0] = selectable_idx + 1 # Now check for always on scripts - if request.alwayson_scripts and (len(request.alwayson_scripts) > 0): + if request.alwayson_scripts: for alwayson_script_name in request.alwayson_scripts.keys(): alwayson_script = self.get_script(alwayson_script_name, script_runner) if alwayson_script is None: diff --git a/modules/call_queue.py b/modules/call_queue.py index 53af6d70..1b5e5273 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -21,7 +21,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): def f(*args, **kwargs): # if the first argument is a string that says "task(...)", it is treated as a job id - if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")": + if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"): id_task = args[0] progress.add_task_to_queue(id_task) else: diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py index aa2a14ef..b6a6dc0e 100644 --- a/modules/extra_networks_hypernet.py +++ b/modules/extra_networks_hypernet.py @@ -9,7 +9,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork): def activate(self, p, params_list): additional = shared.opts.sd_hypernetwork - if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0: + if additional != "None" and additional in shared.hypernetworks and not any(x for x in params_list if x.items[0] == additional): hypernet_prompt_text = f"" p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts] params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) @@ -17,7 +17,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork): names = [] multipliers = [] for params in params_list: - assert len(params.items) > 0 + assert params.items names.append(params.items[0]) multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 071bd9ea..237401a1 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -55,7 +55,7 @@ def image_from_url_text(filedata): if filedata is None: return None - if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False): + if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False): filedata = filedata[0] if type(filedata) == dict and filedata.get("is_file", False): @@ -437,7 +437,7 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, vals_pairs = [f"{k}: {v}" for k, v in vals.items()] - return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0) + return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs)) paste_fields = paste_fields + [(override_settings_component, paste_settings)] @@ -454,5 +454,3 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, outputs=[], show_progress=False, ) - - diff --git a/modules/images.py b/modules/images.py index a12d252b..7bbfc3e0 100644 --- a/modules/images.py +++ b/modules/images.py @@ -406,7 +406,7 @@ class FilenameGenerator: prompt_no_style = self.prompt for style in shared.prompt_styles.get_style_prompts(self.p.styles): - if len(style) > 0: + if style: for part in style.split("{prompt}"): prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',') @@ -415,7 +415,7 @@ class FilenameGenerator: return sanitize_filename_part(prompt_no_style, replace_spaces=False) def prompt_words(self): - words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0] + words = [x for x in re_nonletters.split(self.prompt or "") if x] if len(words) == 0: words = ["empty"] return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False) @@ -423,7 +423,7 @@ class FilenameGenerator: def datetime(self, *args): time_datetime = datetime.datetime.now() - time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format + time_format = args[0] if (args and args[0] != "") else self.default_time_format try: time_zone = pytz.timezone(args[1]) if len(args) > 1 else None except pytz.exceptions.UnknownTimeZoneError: diff --git a/modules/img2img.py b/modules/img2img.py index 4c12c2c5..35c4facc 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -21,8 +21,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): is_inpaint_batch = False if inpaint_mask_dir: inpaint_masks = shared.listfiles(inpaint_mask_dir) - is_inpaint_batch = len(inpaint_masks) > 0 - if is_inpaint_batch: + is_inpaint_batch = bool(inpaint_masks) print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.") print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py index 3fb76b65..b892d5fc 100644 --- a/modules/models/diffusion/ddpm_edit.py +++ b/modules/models/diffusion/ddpm_edit.py @@ -230,9 +230,9 @@ class DDPM(pl.LightningModule): missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: + if missing: print(f"Missing Keys: {missing}") - if len(unexpected) > 0: + if unexpected: print(f"Unexpected Keys: {unexpected}") def q_mean_variance(self, x_start, t): diff --git a/modules/processing.py b/modules/processing.py index 362ab4c2..9ebdb549 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -975,7 +975,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") if self.enable_hr and latent_scale_mode is None: - assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}" + if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers): + raise Exception(f"could not find upscaler named {self.hr_upscaler}") x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index b4aff704..0069d8b0 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -336,11 +336,11 @@ def parse_prompt_attention(text): round_brackets.append(len(res)) elif text == '[': square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: + elif weight is not None and round_brackets: multiply_range(round_brackets.pop(), float(weight)) - elif text == ')' and len(round_brackets) > 0: + elif text == ')' and round_brackets: multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == ']' and len(square_brackets) > 0: + elif text == ']' and square_brackets: multiply_range(square_brackets.pop(), square_bracket_multiplier) else: parts = re.split(re_break, text) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index f755283c..77ee55ee 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -287,14 +287,14 @@ def list_unets_callback(): def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] - filename = stack[0].filename if len(stack) > 0 else 'unknown file' + filename = stack[0].filename if stack else 'unknown file' callbacks.append(ScriptCallback(filename, fun)) def remove_current_script_callbacks(): stack = [x for x in inspect.stack() if x.filename != __file__] - filename = stack[0].filename if len(stack) > 0 else 'unknown file' + filename = stack[0].filename if stack else 'unknown file' if filename == 'unknown file': return for callback_list in callback_map.values(): diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index cc6e8c21..3b5a7666 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -167,7 +167,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): chunk.multipliers += [weight] * emb_len position += embedding_length_in_tokens - if len(chunk.tokens) > 0 or len(chunks) == 0: + if chunk.tokens or not chunks: next_chunk(is_last=True) return chunks, token_count diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py index a3476e95..c5c6270b 100644 --- a/modules/sd_hijack_clip_old.py +++ b/modules/sd_hijack_clip_old.py @@ -74,7 +74,7 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text self.hijack.comments += hijack_comments - if len(used_custom_terms) > 0: + if used_custom_terms: embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms) self.hijack.comments.append(f"Used embeddings: {embedding_names}") diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 8e667a4d..75705459 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -77,27 +77,27 @@ def focal_point(im, settings): pois = [] weight_pref_total = 0 - if len(corner_points) > 0: + if corner_points: weight_pref_total += settings.corner_points_weight - if len(entropy_points) > 0: + if entropy_points: weight_pref_total += settings.entropy_points_weight - if len(face_points) > 0: + if face_points: weight_pref_total += settings.face_points_weight corner_centroid = None - if len(corner_points) > 0: + if corner_points: corner_centroid = centroid(corner_points) corner_centroid.weight = settings.corner_points_weight / weight_pref_total pois.append(corner_centroid) entropy_centroid = None - if len(entropy_points) > 0: + if entropy_points: entropy_centroid = centroid(entropy_points) entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total pois.append(entropy_centroid) face_centroid = None - if len(face_points) > 0: + if face_points: face_centroid = centroid(face_points) face_centroid.weight = settings.face_points_weight / weight_pref_total pois.append(face_centroid) @@ -187,7 +187,7 @@ def image_face_points(im, settings): except Exception: continue - if len(faces) > 0: + if faces: rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects] return [] diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index b9621fc9..7ee05061 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -32,7 +32,7 @@ class DatasetEntry: class PersonalizedBase(Dataset): def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False): - re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None + re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None self.placeholder_token = placeholder_token diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index a009d8e8..0d4c3f84 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -47,7 +47,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti caption += shared.interrogator.generate_caption(image) if params.process_caption_deepbooru: - if len(caption) > 0: + if caption: caption += ", " caption += deepbooru.model.tag_multi(image) @@ -67,7 +67,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti caption = caption.strip() - if len(caption) > 0: + if caption: with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file: file.write(caption) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 8da050ca..bb6f211c 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -251,7 +251,7 @@ class EmbeddingDatabase: if self.previously_displayed_embeddings != displayed_embeddings: self.previously_displayed_embeddings = displayed_embeddings print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") - if len(self.skipped_embeddings) > 0: + if self.skipped_embeddings: print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") def find_embedding_at_position(self, tokens, offset): diff --git a/modules/ui.py b/modules/ui.py index b7459f08..9a025cca 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -398,7 +398,7 @@ def create_override_settings_dropdown(tabname, row): dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True) dropdown.change( - fn=lambda x: gr.Dropdown.update(visible=len(x) > 0), + fn=lambda x: gr.Dropdown.update(visible=bool(x)), inputs=[dropdown], outputs=[dropdown], ) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 3140ed64..65173e06 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -333,7 +333,8 @@ def install_extension_from_url(dirname, url, branch_name=None): assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}' normalized_url = normalize_git_url(url) - assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed' + if any(x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url): + raise Exception(f'Extension with this URL is already installed: {url}') tmpdir = os.path.join(paths.data_path, "tmp", dirname) @@ -449,7 +450,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" existing = installed_extension_urls.get(normalize_git_url(url), None) extension_tags = extension_tags + ["installed"] if existing else extension_tags - if len([x for x in extension_tags if x in tags_to_hide]) > 0: + if any(x for x in extension_tags if x in tags_to_hide): hidden += 1 continue diff --git a/modules/ui_settings.py b/modules/ui_settings.py index 7874298e..2688d8c2 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -81,7 +81,7 @@ class UiSettings: opts.save(shared.config_filename) except RuntimeError: return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' - return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed{": " if changed else ""}{", ".join(changed)}.' def run_settings_single(self, value, key): if not opts.same_type(value, opts.data_labels[key].default): diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 83a2f220..50320d55 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -121,8 +121,7 @@ class Script(scripts.Script): return [checkbox_iterate, checkbox_iterate_batch, prompt_txt] def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str): - lines = [x.strip() for x in prompt_txt.splitlines()] - lines = [x for x in lines if len(x) > 0] + lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x] p.do_not_save_grid = True -- cgit v1.2.3 From f098e726d3e63aae8a6276ce83c55ac905c4379c Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 4 Jun 2023 04:24:44 +0900 Subject: fix conds caching with extra network --- modules/extra_networks.py | 3 +++ modules/processing.py | 24 ++++++++++++------------ 2 files changed, 15 insertions(+), 12 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/extra_networks.py b/modules/extra_networks.py index f4743928..5c5c9a53 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -32,6 +32,9 @@ class ExtraNetworkParams: else: self.positional.append(item) + def __eq__(self, other): + return self.items == other.items and self.positional == other.positional and self.named == other.named + class ExtraNetwork: def __init__(self, name): diff --git a/modules/processing.py b/modules/processing.py index 362ab4c2..22ddc256 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -171,6 +171,7 @@ class StableDiffusionProcessing: self.prompts = None self.negative_prompts = None + self.extra_network_data = None self.seeds = None self.subseeds = None @@ -311,7 +312,7 @@ class StableDiffusionProcessing: self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts] self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts] - def get_conds_with_caching(self, function, required_prompts, steps, cache): + def get_conds_with_caching(self, function, required_prompts, steps, cache, extra_network_data): """ Returns the result of calling function(shared.sd_model, required_prompts, steps) using a cache to store the result if the same arguments have been used before. @@ -321,21 +322,21 @@ class StableDiffusionProcessing: have been used before. The second element is where the previously computed result is stored. """ - if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) == cache[0]: + if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]: return cache[1] with devices.autocast(): cache[1] = function(shared.sd_model, required_prompts, steps) - cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) + cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) return cache[1] def setup_conds(self): sampler_config = sd_samplers.find_sampler_config(self.sampler_name) self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 - self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc) - self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c) + self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.extra_network_data) + self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c, self.extra_network_data) def parse_extra_network_prompts(self): self.prompts, extra_network_data = extra_networks.parse_prompts(self.prompts) @@ -681,7 +682,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if state.job_count == -1: state.job_count = p.n_iter - extra_network_data = None for n in range(p.n_iter): p.iteration = n @@ -702,11 +702,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if len(p.prompts) == 0: break - extra_network_data = p.parse_extra_network_prompts() + p.extra_network_data = p.parse_extra_network_prompts() if not p.disable_extra_networks: with devices.autocast(): - extra_networks.activate(p, extra_network_data) + extra_networks.activate(p, p.extra_network_data) if p.scripts is not None: p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds) @@ -828,8 +828,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) - if not p.disable_extra_networks and extra_network_data: - extra_networks.deactivate(p, extra_network_data) + if not p.disable_extra_networks and p.extra_network_data: + extra_networks.deactivate(p, p.extra_network_data) devices.torch_gc() @@ -1101,8 +1101,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): super().setup_conds() if self.enable_hr: - self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc) - self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c) + self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.hr_extra_network_data) + self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c, self.hr_extra_network_data) def parse_extra_network_prompts(self): res = super().parse_extra_network_prompts() -- cgit v1.2.3 From 1c9d1b0ee06cffe0ba1b33a5795755cf39e5dde2 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 4 Jun 2023 05:19:34 +0900 Subject: simplify self.extra_network_data --- modules/processing.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 22ddc256..fae83788 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -339,9 +339,7 @@ class StableDiffusionProcessing: self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c, self.extra_network_data) def parse_extra_network_prompts(self): - self.prompts, extra_network_data = extra_networks.parse_prompts(self.prompts) - - return extra_network_data + self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts) class Processed: @@ -702,7 +700,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if len(p.prompts) == 0: break - p.extra_network_data = p.parse_extra_network_prompts() + p.parse_extra_network_prompts() if not p.disable_extra_networks: with devices.autocast(): -- cgit v1.2.3 From 1ca5e76f7b122ba35fc807350624a8d3fc25058a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 4 Jun 2023 13:07:22 +0300 Subject: fix for conds of second hires fox pass being calculated using first pass's networks, and add an option to revert to old behavior --- modules/lowvram.py | 6 ++++++ modules/processing.py | 32 +++++++++++++++++++++++++++++--- modules/shared.py | 1 + 3 files changed, 36 insertions(+), 3 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/lowvram.py b/modules/lowvram.py index e254cc13..d95bcfbf 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -15,6 +15,8 @@ def send_everything_to_cpu(): def setup_for_low_vram(sd_model, use_medvram): + sd_model.lowvram = True + parents = {} def send_me_to_gpu(module, _): @@ -96,3 +98,7 @@ def setup_for_low_vram(sd_model, use_medvram): diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu) for block in diff_model.output_blocks: block.register_forward_pre_hook(send_me_to_gpu) + + +def is_enabled(sd_model): + return getattr(sd_model, 'lowvram', False) diff --git a/modules/processing.py b/modules/processing.py index fae83788..b65c7beb 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -739,7 +739,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: del samples_ddim - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + if lowvram.is_enabled(shared.sd_model): lowvram.send_everything_to_cpu() devices.torch_gc() @@ -894,6 +894,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_negative_prompts = None self.hr_extra_network_data = None + self.cached_hr_uc = [None, None] + self.cached_hr_c = [None, None] self.hr_c = None self.hr_uc = None @@ -1056,6 +1058,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): with devices.autocast(): extra_networks.activate(self, self.hr_extra_network_data) + with devices.autocast(): + self.calculate_hr_conds() + sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True)) samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) @@ -1067,6 +1072,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): return samples def close(self): + self.cached_hr_uc = [None, None] + self.cached_hr_c = [None, None] self.hr_c = None self.hr_uc = None @@ -1095,12 +1102,31 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts] self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts] + def calculate_hr_conds(self): + if self.hr_c is not None: + return + + self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_hr_uc, self.hr_extra_network_data) + self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_hr_c, self.hr_extra_network_data) + def setup_conds(self): super().setup_conds() + self.hr_uc = None + self.hr_c = None + if self.enable_hr: - self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.hr_extra_network_data) - self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c, self.hr_extra_network_data) + if shared.opts.hires_fix_use_firstpass_conds: + self.calculate_hr_conds() + + elif lowvram.is_enabled(shared.sd_model): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded + with devices.autocast(): + extra_networks.activate(self, self.hr_extra_network_data) + + self.calculate_hr_conds() + + with devices.autocast(): + extra_networks.activate(self, self.extra_network_data) def parse_extra_network_prompts(self): res = super().parse_extra_network_prompts() diff --git a/modules/shared.py b/modules/shared.py index 7d056a4d..2bd7c6ec 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -429,6 +429,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."), "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."), + "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { -- cgit v1.2.3 From fbf88343deff2f0d6c1c375ec858c094ed9fa260 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 4 Jun 2023 16:29:02 +0300 Subject: prevent calculating cons for second pass of hires fix when they are the same as for the first pass --- modules/processing.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index b65c7beb..230053e4 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -312,7 +312,7 @@ class StableDiffusionProcessing: self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts] self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts] - def get_conds_with_caching(self, function, required_prompts, steps, cache, extra_network_data): + def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data): """ Returns the result of calling function(shared.sd_model, required_prompts, steps) using a cache to store the result if the same arguments have been used before. @@ -321,9 +321,15 @@ class StableDiffusionProcessing: representing the previously used arguments, or None if no arguments have been used before. The second element is where the previously computed result is stored. + + caches is a list with items described above. """ - if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]: - return cache[1] + + for cache in caches: + if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]: + return cache[1] + + cache = caches[0] with devices.autocast(): cache[1] = function(shared.sd_model, required_prompts, steps) @@ -335,8 +341,8 @@ class StableDiffusionProcessing: sampler_config = sd_samplers.find_sampler_config(self.sampler_name) self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 - self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.extra_network_data) - self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c, self.extra_network_data) + self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) + self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) def parse_extra_network_prompts(self): self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts) @@ -1106,8 +1112,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if self.hr_c is not None: return - self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_hr_uc, self.hr_extra_network_data) - self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_hr_c, self.hr_extra_network_data) + self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data) + self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data) def setup_conds(self): super().setup_conds() -- cgit v1.2.3 From 7f2214aa2b9a88bc1631f3f1d164a75a3583c851 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 8 Jun 2023 13:53:02 +0900 Subject: persistent conds cache Update shared.py --- modules/processing.py | 26 ++++++++++++++++---------- modules/shared.py | 1 + 2 files changed, 17 insertions(+), 10 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index f8225b83..1169bd7f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -106,6 +106,9 @@ class StableDiffusionProcessing: """ The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing """ + cached_uc = [None, None] + cached_c = [None, None] + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): if sampler_index is not None: print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) @@ -176,8 +179,8 @@ class StableDiffusionProcessing: self.subseeds = None self.step_multiplier = 1 - self.cached_uc = [None, None] - self.cached_c = [None, None] + self.cached_uc = StableDiffusionProcessing.cached_uc + self.cached_c = StableDiffusionProcessing.cached_c self.uc = None self.c = None @@ -289,8 +292,9 @@ class StableDiffusionProcessing: self.sampler = None self.c = None self.uc = None - self.cached_c = [None, None] - self.cached_uc = [None, None] + if not opts.experimental_persistent_cond_cache: + StableDiffusionProcessing.cached_c = [None, None] + StableDiffusionProcessing.cached_uc = [None, None] def get_token_merging_ratio(self, for_hr=False): if for_hr: @@ -324,7 +328,6 @@ class StableDiffusionProcessing: caches is a list with items described above. """ - for cache in caches: if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]: return cache[1] @@ -340,7 +343,6 @@ class StableDiffusionProcessing: def setup_conds(self): sampler_config = sd_samplers.find_sampler_config(self.sampler_name) self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 - self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) @@ -868,6 +870,8 @@ def old_hires_fix_first_pass_dimensions(width, height): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None + cached_hr_uc = [None, None] + cached_hr_c = [None, None] def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs): super().__init__(**kwargs) @@ -900,8 +904,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_negative_prompts = None self.hr_extra_network_data = None - self.cached_hr_uc = [None, None] - self.cached_hr_c = [None, None] + self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc + self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c self.hr_c = None self.hr_uc = None @@ -1079,10 +1083,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): return samples def close(self): - self.cached_hr_uc = [None, None] - self.cached_hr_c = [None, None] + super().close() self.hr_c = None self.hr_uc = None + if not opts.experimental_persistent_cond_cache: + StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None] + StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None] def setup_prompts(self): super().setup_prompts() diff --git a/modules/shared.py b/modules/shared.py index c9ee2dd1..91c31d55 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -421,6 +421,7 @@ options_templates.update(options_section(('optimizations', "Optimizations"), { "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"), "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"), "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"), + "experimental_persistent_cond_cache": OptionInfo(False, "persistent cond cache").info("Experimental, keep cond caches across jobs, reduce overhead."), })) options_templates.update(options_section(('compatibility', "Compatibility"), { -- cgit v1.2.3 From 1503af60b0550d1a57c95a500df1fe1390178a27 Mon Sep 17 00:00:00 2001 From: Splendide Imaginarius <119545140+Splendide-Imaginarius@users.noreply.github.com> Date: Thu, 11 May 2023 22:03:44 +0000 Subject: Split mask blur into X and Y components Prequisite to fixing Outpainting MK2 mask blur bug. --- modules/processing.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index f8225b83..73ff10f4 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1150,7 +1150,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): sampler = None - def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs): + def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs): super().__init__(**kwargs) self.init_images = init_images @@ -1161,7 +1161,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.image_mask = mask self.latent_mask = None self.mask_for_overlay = None - self.mask_blur = mask_blur + if mask_blur is not None: + mask_blur_x = mask_blur + mask_blur_y = mask_blur + self.mask_blur_x = mask_blur_x + self.mask_blur_y = mask_blur_y self.inpainting_fill = inpainting_fill self.inpaint_full_res = inpaint_full_res self.inpaint_full_res_padding = inpaint_full_res_padding @@ -1183,8 +1187,17 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.inpainting_mask_invert: image_mask = ImageOps.invert(image_mask) - if self.mask_blur > 0: - image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)) + if self.mask_blur_x > 0: + np_mask = np.array(image_mask) + kernel_size = 2 * int(4 * self.mask_blur_x + 0.5) + 1 + np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x) + image_mask = Image.fromarray(np_mask) + + if self.mask_blur_y > 0: + np_mask = np.array(image_mask) + kernel_size = 2 * int(4 * self.mask_blur_y + 0.5) + 1 + np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y) + image_mask = Image.fromarray(np_mask) if self.inpaint_full_res: self.mask_for_overlay = image_mask -- cgit v1.2.3 From cfdd1b941897eee994e92fc2149c0de7da6eed4a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 9 Jun 2023 22:47:27 +0300 Subject: linter --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index b700a718..8da73884 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -7,7 +7,7 @@ import hashlib import torch import numpy as np -from PIL import Image, ImageFilter, ImageOps +from PIL import Image, ImageOps import random import cv2 from skimage import exposure -- cgit v1.2.3 From d3c86e5178725b11a4679097f0aefb0a9fc90014 Mon Sep 17 00:00:00 2001 From: Jared Deckard Date: Wed, 14 Jun 2023 14:03:44 -0500 Subject: Note the Gradio user in the Exif data --- modules/img2img.py | 5 ++++- modules/processing.py | 3 +++ modules/txt2img.py | 6 ++++-- 3 files changed, 11 insertions(+), 3 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/img2img.py b/modules/img2img.py index d704bf90..83bd7857 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -2,6 +2,7 @@ import os import numpy as np from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError +import gradio as gr from modules import sd_samplers from modules.generation_parameters_copypaste import create_override_settings_dict @@ -78,7 +79,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, *args): +def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) is_batch = mode == 5 @@ -160,6 +161,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s p.scripts = modules.scripts.scripts_img2img p.script_args = args + p.user = request.username + if shared.cmd_opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/processing.py b/modules/processing.py index d22b353f..3e8d153e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -180,6 +180,8 @@ class StableDiffusionProcessing: self.uc = None self.c = None + self.user = None + @property def sd_model(self): return shared.sd_model @@ -578,6 +580,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond, **p.extra_generation_params, "Version": program_version() if opts.add_version_to_infotext else None, + "User": p.user, } generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) diff --git a/modules/txt2img.py b/modules/txt2img.py index 2e7d202d..6aa79f23 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -4,10 +4,10 @@ from modules.generation_parameters_copypaste import create_override_settings_dic from modules.shared import opts, cmd_opts import modules.shared as shared from modules.ui import plaintext_to_html +import gradio as gr - -def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, *args): +def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) p = processing.StableDiffusionProcessingTxt2Img( @@ -48,6 +48,8 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step p.scripts = modules.scripts.scripts_txt2img p.script_args = args + p.user = request.username + if cmd_opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) -- cgit v1.2.3 From f603275d84301b5ee952683e951dd1aad72ba615 Mon Sep 17 00:00:00 2001 From: Jared Deckard Date: Thu, 15 Jun 2023 10:55:53 -0500 Subject: Add an opt-in infotext user name setting --- modules/processing.py | 2 +- modules/shared.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 3e8d153e..a0cc8db2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -580,7 +580,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond, **p.extra_generation_params, "Version": program_version() if opts.add_version_to_infotext else None, - "User": p.user, + "User": p.user if opts.add_user_name_to_info else None, } generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) diff --git a/modules/shared.py b/modules/shared.py index 271a062d..4c639a21 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -496,6 +496,7 @@ options_templates.update(options_section(('ui', "User interface"), { options_templates.update(options_section(('infotext', "Infotext"), { "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), + "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"), "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"), "disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), })) -- cgit v1.2.3 From 27e9e3f6fa41c10eb5256662ddf3643dee933810 Mon Sep 17 00:00:00 2001 From: stablegeniusdiffuser Date: Mon, 19 Jun 2023 20:36:44 +0200 Subject: Add use_main_prompt parameter to use proper metadata for batch run grids or individual images --- modules/processing.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 8da73884..1d97e95e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -549,7 +549,7 @@ def program_version(): return res -def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0): +def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False): index = position_in_batch + iteration * p.batch_size clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers) @@ -589,9 +589,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) + prompt_text = p.prompt if use_main_prompt else all_prompts[index] negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else "" - return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() + return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip() def process_images(p: StableDiffusionProcessing) -> Processed: @@ -663,8 +664,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: else: p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))] - def infotext(iteration=0, position_in_batch=0): - return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch) + def infotext(iteration=0, position_in_batch=0, use_main_prompt=False): + return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt) if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: model_hijack.embedding_db.load_textual_inversion_embeddings() @@ -824,7 +825,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: grid = images.image_grid(output_images, p.batch_size) if opts.return_grid: - text = infotext() + text = infotext(use_main_prompt=True) infotexts.insert(0, text) if opts.enable_pnginfo: grid.info["parameters"] = text @@ -832,7 +833,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: index_of_first_image = 1 if opts.grid_save: - images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) + images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True) if not p.disable_extra_networks and p.extra_network_data: extra_networks.deactivate(p, p.extra_network_data) -- cgit v1.2.3 From b0ec69b360835a901a1aa57df1f7c8c9d55bf31c Mon Sep 17 00:00:00 2001 From: hako-mikan <122196982+hako-mikan@users.noreply.github.com> Date: Wed, 28 Jun 2023 18:37:08 +0900 Subject: add 'before_hr callback' script callback --- modules/processing.py | 3 +++ modules/scripts.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 8da73884..35463c37 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1074,6 +1074,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True)) + if self.scripts is not None: + self.scripts.before_hr(self) + samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio()) diff --git a/modules/scripts.py b/modules/scripts.py index 99bf836a..6485f398 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -186,6 +186,11 @@ class Script: return f'script_{tabname}{title}_{item_id}' + def before_hr(self, p ,*args): + """ + This function is called before hires fix start. + """ + pass current_basedir = paths.script_path @@ -548,6 +553,15 @@ class ScriptRunner: self.scripts[si].args_to = args_to + def before_hr(self, p): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.before_hr(p, *script_args) + except Exception: + errors.report(f"Error running before_hr: {script.filename}", exc_info=True) + + scripts_txt2img: ScriptRunner = None scripts_img2img: ScriptRunner = None scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None -- cgit v1.2.3 From 7f46f81dd7b517e829395734750f0eb8360675d4 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Sat, 1 Jul 2023 17:20:56 -0600 Subject: Change default seed_resize to 0 --- modules/processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 8da73884..9e838aad 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -109,7 +109,7 @@ class StableDiffusionProcessing: cached_uc = [None, None] cached_c = [None, None] - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = 0, seed_resize_from_w: int = 0, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): if sampler_index is not None: print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) -- cgit v1.2.3 From f731a728c68035ee36317ed0096ac5ecbfd50553 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Mon, 3 Jul 2023 11:41:10 -0600 Subject: Check seed_resize_from <= 0 --- modules/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 9e838aad..dc552121 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -109,7 +109,7 @@ class StableDiffusionProcessing: cached_uc = [None, None] cached_c = [None, None] - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = 0, seed_resize_from_w: int = 0, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): if sampler_index is not None: print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) @@ -573,7 +573,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), - "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), + "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), "Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None, "Clip skip": None if clip_skip <= 1 else clip_skip, -- cgit v1.2.3