From 740070ea9cdb254209f66417418f2a4af8b099d6 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Mon, 26 Sep 2022 09:29:50 -0500 Subject: Re-implement universal model loading --- modules/shared.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index c32da110..1444040d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -16,11 +16,11 @@ import modules.sd_models sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file - +model_path = os.path.join(script_path, 'models') parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",) -parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",) +parser.add_argument("--ckpt-dir", type=str, default=model_path, help="path to directory with stable diffusion checkpoints",) parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") @@ -34,8 +34,12 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") -parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN')) -parser.add_argument("--swinir-models-path", type=str, help="path to directory with SwinIR models", default=os.path.join(script_path, 'SwinIR')) +parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model(s)", default=os.path.join(model_path, 'Codeformer')) +parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model(s)", default=os.path.join(model_path, 'GFPGAN')) +parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN models", default=os.path.join(model_path, 'ESRGAN')) +parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN models", default=os.path.join(model_path, 'RealESRGAN')) +parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR models", default=os.path.join(model_path, 'SwinIR')) +parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR models", default=os.path.join(model_path, 'LDSR')) parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") -- cgit v1.2.3 From 11875f586323cea7c5b8398976449788a83dee76 Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Tue, 27 Sep 2022 11:01:13 -0500 Subject: Use model loader with stable-diffusion too. Hook the model loader into the SD_models file. Add default url/download if checkpoint is not found. Add matching stablediffusion-models-path argument. Add message that --ckpt-dir will be removed in the future, but have it pipe to stablediffusion-models-path for now. Update help strings for models-path args so they're more or less uniform. Move sd_model "setup" call to webUI with the others. Ensure "cleanup_models" method moves existing models to the new locations, including SD, and that we aren't deleting folders that still have stuff in them. --- modules/modelloader.py | 25 ++++++++++++++++++------- modules/sd_models.py | 48 +++++++++++++++++++++++++++++------------------- modules/shared.py | 28 +++++++++++++++++----------- webui.py | 1 + 4 files changed, 65 insertions(+), 37 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/modelloader.py b/modules/modelloader.py index 9520a681..2ee364f0 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -45,7 +45,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None if file not in existing: path = os.path.join(place, file) existing.append(path) - if model_url is not None: + if model_url is not None and len(existing) == 0: if dl_name is not None: model_file = load_file_from_url(url=model_url, model_dir=model_path, file_name=dl_name, progress=True) else: @@ -69,7 +69,13 @@ def friendly_name(file: str): def cleanup_models(): + # This code could probably be more efficient if we used a tuple list or something to store the src/destinations + # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler + # somehow auto-register and just do these things... root_path = script_path + src_path = models_path + dest_path = os.path.join(models_path, "Stable-diffusion") + move_files(src_path, dest_path, ".ckpt") src_path = os.path.join(root_path, "ESRGAN") dest_path = os.path.join(models_path, "ESRGAN") move_files(src_path, dest_path) @@ -84,20 +90,25 @@ def cleanup_models(): move_files(src_path, dest_path) -def move_files(src_path: str, dest_path: str): +def move_files(src_path: str, dest_path: str, ext_filter: str = None): try: if not os.path.exists(dest_path): os.makedirs(dest_path) if os.path.exists(src_path): for file in os.listdir(src_path): - if os.path.isfile(file): - fullpath = os.path.join(src_path, file) - print("Moving file: %s to %s" % (fullpath, dest_path)) + fullpath = os.path.join(src_path, file) + if os.path.isfile(fullpath): + print(f"Checking file {file} in {src_path}") + if ext_filter is not None: + if ext_filter not in file: + continue + print(f"Moving {file} from {src_path} to {dest_path}.") try: shutil.move(fullpath, dest_path) except: pass - print("Removing folder: %s" % src_path) - shutil.rmtree(src_path, True) + if len(os.listdir(src_path)) == 0: + print(f"Removing empty folder: {src_path}") + shutil.rmtree(src_path, True) except: pass \ No newline at end of file diff --git a/modules/sd_models.py b/modules/sd_models.py index dc81b0dc..89b7d276 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -8,7 +8,13 @@ from omegaconf import OmegaConf from ldm.util import instantiate_from_config -from modules import shared +from modules import shared, modelloader +from modules.paths import models_path + +model_dir = "Stable-diffusion" +model_path = os.path.join(models_path, model_dir) +model_name = "sd-v1-4.ckpt" +model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1" CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash']) checkpoints_list = {} @@ -23,23 +29,28 @@ except Exception: pass -def list_models(): - checkpoints_list.clear() +def modeltitle(path, h): + abspath = os.path.abspath(path) - model_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) + if abspath.startswith(model_dir): + name = abspath.replace(model_dir, '') + else: + name = os.path.basename(path) - def modeltitle(path, h): - abspath = os.path.abspath(path) + if name.startswith("\\") or name.startswith("/"): + name = name[1:] - if abspath.startswith(model_dir): - name = abspath.replace(model_dir, '') - else: - name = os.path.basename(path) + return f'{name} [{h}]' - if name.startswith("\\") or name.startswith("/"): - name = name[1:] - return f'{name} [{h}]' +def setup_model(dirname): + global model_path + global model_name + global model_url + if not os.path.exists(model_path): + os.makedirs(model_path) + checkpoints_list.clear() + model_list = modelloader.load_models(model_path, model_url, dirname, model_name, ext_filter=".ckpt") cmd_ckpt = shared.cmd_opts.ckpt if os.path.exists(cmd_ckpt): @@ -47,13 +58,12 @@ def list_models(): title = modeltitle(cmd_ckpt, h) checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h) elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: - print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr) + print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) - if os.path.exists(model_dir): - for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True): - h = model_hash(filename) - title = modeltitle(filename, h) - checkpoints_list[title] = CheckpointInfo(filename, title, h) + for filename in model_list: + h = model_hash(filename) + title = modeltitle(filename, h) + checkpoints_list[title] = CheckpointInfo(filename, title, h) def model_hash(filename): diff --git a/modules/shared.py b/modules/shared.py index 1444040d..e617bce4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -20,7 +20,8 @@ model_path = os.path.join(script_path, 'models') parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",) -parser.add_argument("--ckpt-dir", type=str, default=model_path, help="path to directory with stable diffusion checkpoints",) +# This should be deprecated, but we'll leave it for a few iterations +parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints (Deprecated, use '--stablediffusion-models-path'", ) parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") @@ -34,12 +35,13 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") -parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model(s)", default=os.path.join(model_path, 'Codeformer')) -parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model(s)", default=os.path.join(model_path, 'GFPGAN')) -parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN models", default=os.path.join(model_path, 'ESRGAN')) -parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN models", default=os.path.join(model_path, 'RealESRGAN')) -parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR models", default=os.path.join(model_path, 'SwinIR')) -parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR models", default=os.path.join(model_path, 'LDSR')) +parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(model_path, 'Codeformer')) +parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(model_path, 'GFPGAN')) +parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN')) +parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN')) +parser.add_argument("--stablediffusion-models-path", type=str, help="Path to directory with Stable-diffusion checkpoints.", default=os.path.join(model_path, 'SwinIR')) +parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR')) +parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR')) parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") @@ -57,7 +59,10 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) cmd_opts = parser.parse_args() - +if cmd_opts.ckpt_dir is not None: + print("The 'ckpt-dir' arg is deprecated in favor of the 'stablediffusion-models-path' argument and will be " + "removed in a future release. Please use the new option if you wish to use a custom checkpoint directory.") + cmd_opts.__setattr__("stablediffusion-models-path", cmd_opts.ckpt_dir) device = get_optimal_device() batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) @@ -65,6 +70,7 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram config_filename = cmd_opts.ui_settings_file + class State: interrupted = False job = "" @@ -98,8 +104,8 @@ prompt_styles = modules.styles.StyleDatabase(styles_filename) interrogator = modules.interrogate.InterrogateModels("interrogate") face_restorers = [] - -modules.sd_models.list_models() +# This was moved to webui.py with the other model "setup" calls. +# modules.sd_models.list_models() def realesrgan_models_names(): @@ -195,7 +201,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": [x.title for x in modules.sd_models.checkpoints_list.values()]}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), - "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), + "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), "enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text and [text] to make it pay less attention"), diff --git a/webui.py b/webui.py index e71a217c..15e49bf6 100644 --- a/webui.py +++ b/webui.py @@ -23,6 +23,7 @@ from modules.paths import script_path from modules.shared import cmd_opts modelloader.cleanup_models() +modules.sd_models.setup_model(cmd_opts.stablediffusion_models_path) codeformer.setup_model(cmd_opts.codeformer_models_path) gfpgan.setup_model(cmd_opts.gfpgan_models_path) shared.face_restorers.append(modules.face_restoration.FaceRestoration()) -- cgit v1.2.3 From 0dce0df1ee63b2f158805c1a1f1a3743cc4a104b Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Thu, 29 Sep 2022 17:46:23 -0500 Subject: Holy $hit. Yep. Fix gfpgan_model_arch requirement(s). Add Upscaler base class, move from images. Add a lot of methods to Upscaler. Re-work all the child upscalers to be proper classes. Add BSRGAN scaler. Add ldsr_model_arch class, removing the dependency for another repo that just uses regular latent-diffusion stuff. Add one universal method that will always find and load new upscaler models without having to add new "setup_model" calls. Still need to add command line params, but that could probably be automated. Add a "self.scale" property to all Upscalers so the scalers themselves can do "things" in response to the requested upscaling size. Ensure LDSR doesn't get stuck in a longer loop of "upscale/downscale/upscale" as we try to reach the target upscale size. Add typehints for IDE sanity. PEP-8 improvements. Moar. --- launch.py | 11 ++- modules/bsrgan_model.py | 79 +++++++++++++++ modules/bsrgan_model_arch.py | 103 ++++++++++++++++++++ modules/esrgan_model.py | 227 +++++++++++++++++++++---------------------- modules/extras.py | 35 ++++--- modules/gfpgan_model.py | 58 +++++++---- modules/gfpgan_model_arch.py | 150 ---------------------------- modules/images.py | 84 +++++++--------- modules/ldsr_model.py | 103 +++++++------------- modules/ldsr_model_arch.py | 223 ++++++++++++++++++++++++++++++++++++++++++ modules/modelloader.py | 74 ++++++++++---- modules/realesrgan_model.py | 209 ++++++++++++++++++++------------------- modules/sd_models.py | 3 +- modules/sd_samplers.py | 4 +- modules/shared.py | 18 ++-- modules/swinir_model.py | 157 +++++++++++++----------------- modules/upscaler.py | 121 +++++++++++++++++++++++ webui.py | 9 +- 18 files changed, 1018 insertions(+), 650 deletions(-) create mode 100644 modules/bsrgan_model.py create mode 100644 modules/bsrgan_model_arch.py delete mode 100644 modules/gfpgan_model_arch.py create mode 100644 modules/ldsr_model_arch.py create mode 100644 modules/upscaler.py (limited to 'modules/shared.py') diff --git a/launch.py b/launch.py index 4462631c..77edea26 100644 --- a/launch.py +++ b/launch.py @@ -1,5 +1,5 @@ # this scripts installs necessary requirements and launches main program in webui.py - +import shutil import subprocess import os import sys @@ -22,7 +22,6 @@ stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "6 taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") -ldsr_commit_hash = os.environ.get('LDSR_COMMIT_HASH',"abf33e7002d59d9085081bce93ec798dcabd49af") args = shlex.split(commandline_args) @@ -122,9 +121,11 @@ git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-di git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash) git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash) -# Using my repo until my changes are merged, as this makes interfacing with our version of SD-web a lot easier -git_clone("https://github.com/Hafiidz/latent-diffusion", repo_dir('latent-diffusion'), "LDSR", ldsr_commit_hash) - +if os.path.isdir(repo_dir('latent-diffusion')): + try: + shutil.rmtree(repo_dir('latent-diffusion')) + except: + pass if not is_installed("lpips"): run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer") diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py new file mode 100644 index 00000000..77141545 --- /dev/null +++ b/modules/bsrgan_model.py @@ -0,0 +1,79 @@ +import os.path +import sys +import traceback + +import PIL.Image +import numpy as np +import torch +from basicsr.utils.download_util import load_file_from_url + +import modules.upscaler +from modules import shared, modelloader +from modules.bsrgan_model_arch import RRDBNet +from modules.paths import models_path + + +class UpscalerBSRGAN(modules.upscaler.Upscaler): + def __init__(self, dirname): + self.name = "BSRGAN" + self.model_path = os.path.join(models_path, self.name) + self.model_name = "BSRGAN 4x" + self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth" + self.user_path = dirname + super().__init__() + model_paths = self.find_models(ext_filter=[".pt", ".pth"]) + scalers = [] + if len(model_paths) == 0: + scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4) + scalers.append(scaler_data) + for file in model_paths: + if "http" in file: + name = self.model_name + else: + name = modelloader.friendly_name(file) + try: + scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) + scalers.append(scaler_data) + except Exception: + print(f"Error loading BSRGAN model: {file}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + self.scalers = scalers + + def do_upscale(self, img: PIL.Image, selected_file): + torch.cuda.empty_cache() + model = self.load_model(selected_file) + if model is None: + return img + model.to(shared.device) + torch.cuda.empty_cache() + img = np.array(img) + img = img[:, :, ::-1] + img = np.moveaxis(img, 2, 0) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(shared.device) + with torch.no_grad(): + output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() + output = 255. * np.moveaxis(output, 0, 2) + output = output.astype(np.uint8) + output = output[:, :, ::-1] + torch.cuda.empty_cache() + return PIL.Image.fromarray(output, 'RGB') + + def load_model(self, path: str): + if "http" in path: + filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, + progress=True) + else: + filename = path + if not os.path.exists(filename) or filename is None: + print("Unable to load %s from %s" % (self.model_dir, filename)) + return None + print("Loading %s from %s" % (self.model_dir, filename)) + model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2) # define network + model.load_state_dict(torch.load(filename), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + return model + diff --git a/modules/bsrgan_model_arch.py b/modules/bsrgan_model_arch.py new file mode 100644 index 00000000..d72647db --- /dev/null +++ b/modules/bsrgan_model_arch.py @@ -0,0 +1,103 @@ +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init + + +def initialize_weights(net_l, scale=1): + if not isinstance(net_l, list): + net_l = [net_l] + for net in net_l: + for m in net.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale # for residual block + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + init.constant_(m.weight, 1) + init.constant_(m.bias.data, 0.0) + + +def make_layer(block, n_layers): + layers = [] + for _ in range(n_layers): + layers.append(block()) + return nn.Sequential(*layers) + + +class ResidualDenseBlock_5C(nn.Module): + def __init__(self, nf=64, gc=32, bias=True): + super(ResidualDenseBlock_5C, self).__init__() + # gc: growth channel, i.e. intermediate channels + self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +class RRDB(nn.Module): + '''Residual in Residual Dense Block''' + + def __init__(self, nf, gc=32): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C(nf, gc) + self.RDB2 = ResidualDenseBlock_5C(nf, gc) + self.RDB3 = ResidualDenseBlock_5C(nf, gc) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x + + +class RRDBNet(nn.Module): + def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4): + super(RRDBNet, self).__init__() + RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) + self.sf = sf + print([in_nc, out_nc, nf, nb, gc, sf]) + + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.RRDB_trunk = make_layer(RRDB_block_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + #### upsampling + self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + if self.sf==4: + self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + fea = self.conv_first(x) + trunk = self.trunk_conv(self.RRDB_trunk(fea)) + fea = fea + trunk + + fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) + if self.sf==4: + fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.HRconv(fea))) + + return out \ No newline at end of file diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 5e10c49c..ce841aa4 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,6 +1,4 @@ import os -import sys -import traceback import numpy as np import torch @@ -8,93 +6,119 @@ from PIL import Image from basicsr.utils.download_util import load_file_from_url import modules.esrgam_model_arch as arch -import modules.images -from modules import shared -from modules import shared, modelloader +from modules import shared, modelloader, images from modules.devices import has_mps from modules.paths import models_path +from modules.upscaler import Upscaler, UpscalerData from modules.shared import opts -model_dir = "ESRGAN" -model_path = os.path.join(models_path, model_dir) -model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download" -model_name = "ESRGAN_x4" - - -def load_model(path: str, name: str): - global model_path - global model_url - global model_dir - global model_name - if "http" in path: - filename = load_file_from_url(url=model_url, model_dir=model_path, file_name="%s.pth" % model_name, progress=True) - else: - filename = path - if not os.path.exists(filename) or filename is None: - print("Unable to load %s from %s" % (model_dir, filename)) - return None - print("Loading %s from %s" % (model_dir, filename)) - # this code is adapted from https://github.com/xinntao/ESRGAN - pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None) - crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) - - if 'conv_first.weight' in pretrained_net: - crt_model.load_state_dict(pretrained_net) - return crt_model - if 'model.0.weight' not in pretrained_net: - is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"] - if is_realesrgan: - raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.") - else: - raise Exception("The file is not a ESRGAN model.") +class UpscalerESRGAN(Upscaler): + def __init__(self, dirname): + self.name = "ESRGAN" + self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download" + self.model_name = "ESRGAN 4x" + self.scalers = [] + self.user_path = dirname + self.model_path = os.path.join(models_path, self.name) + super().__init__() + model_paths = self.find_models(ext_filter=[".pt", ".pth"]) + scalers = [] + if len(model_paths) == 0: + scaler_data = UpscalerData(self.model_name, self.model_url, self, 4) + scalers.append(scaler_data) + for file in model_paths: + print(f"File: {file}") + if "http" in file: + name = self.model_name + else: + name = modelloader.friendly_name(file) + + scaler_data = UpscalerData(name, file, self, 4) + print(f"ESRGAN: Adding scaler {name}") + self.scalers.append(scaler_data) + + def do_upscale(self, img, selected_model): + model = self.load_model(selected_model) + if model is None: + return img + model.to(shared.device) + img = esrgan_upscale(model, img) + return img - crt_net = crt_model.state_dict() - load_net_clean = {} - for k, v in pretrained_net.items(): - if k.startswith('module.'): - load_net_clean[k[7:]] = v + def load_model(self, path: str): + if "http" in path: + filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, + file_name="%s.pth" % self.model_name, + progress=True) else: - load_net_clean[k] = v - pretrained_net = load_net_clean - - tbd = [] - for k, v in crt_net.items(): - tbd.append(k) - - # directly copy - for k, v in crt_net.items(): - if k in pretrained_net and pretrained_net[k].size() == v.size(): - crt_net[k] = pretrained_net[k] - tbd.remove(k) - - crt_net['conv_first.weight'] = pretrained_net['model.0.weight'] - crt_net['conv_first.bias'] = pretrained_net['model.0.bias'] - - for k in tbd.copy(): - if 'RDB' in k: - ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') - if '.weight' in k: - ori_k = ori_k.replace('.weight', '.0.weight') - elif '.bias' in k: - ori_k = ori_k.replace('.bias', '.0.bias') - crt_net[k] = pretrained_net[ori_k] - tbd.remove(k) - - crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight'] - crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias'] - crt_net['upconv1.weight'] = pretrained_net['model.3.weight'] - crt_net['upconv1.bias'] = pretrained_net['model.3.bias'] - crt_net['upconv2.weight'] = pretrained_net['model.6.weight'] - crt_net['upconv2.bias'] = pretrained_net['model.6.bias'] - crt_net['HRconv.weight'] = pretrained_net['model.8.weight'] - crt_net['HRconv.bias'] = pretrained_net['model.8.bias'] - crt_net['conv_last.weight'] = pretrained_net['model.10.weight'] - crt_net['conv_last.bias'] = pretrained_net['model.10.bias'] - - crt_model.load_state_dict(crt_net) - crt_model.eval() - return crt_model + filename = path + if not os.path.exists(filename) or filename is None: + print("Unable to load %s from %s" % (self.model_path, filename)) + return None + # this code is adapted from https://github.com/xinntao/ESRGAN + pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None) + crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) + + if 'conv_first.weight' in pretrained_net: + crt_model.load_state_dict(pretrained_net) + return crt_model + + if 'model.0.weight' not in pretrained_net: + is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net[ + "params_ema"] + if is_realesrgan: + raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.") + else: + raise Exception("The file is not a ESRGAN model.") + + crt_net = crt_model.state_dict() + load_net_clean = {} + for k, v in pretrained_net.items(): + if k.startswith('module.'): + load_net_clean[k[7:]] = v + else: + load_net_clean[k] = v + pretrained_net = load_net_clean + + tbd = [] + for k, v in crt_net.items(): + tbd.append(k) + + # directly copy + for k, v in crt_net.items(): + if k in pretrained_net and pretrained_net[k].size() == v.size(): + crt_net[k] = pretrained_net[k] + tbd.remove(k) + + crt_net['conv_first.weight'] = pretrained_net['model.0.weight'] + crt_net['conv_first.bias'] = pretrained_net['model.0.bias'] + + for k in tbd.copy(): + if 'RDB' in k: + ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') + if '.weight' in k: + ori_k = ori_k.replace('.weight', '.0.weight') + elif '.bias' in k: + ori_k = ori_k.replace('.bias', '.0.bias') + crt_net[k] = pretrained_net[ori_k] + tbd.remove(k) + + crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight'] + crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias'] + crt_net['upconv1.weight'] = pretrained_net['model.3.weight'] + crt_net['upconv1.bias'] = pretrained_net['model.3.bias'] + crt_net['upconv2.weight'] = pretrained_net['model.6.weight'] + crt_net['upconv2.bias'] = pretrained_net['model.6.bias'] + crt_net['HRconv.weight'] = pretrained_net['model.8.weight'] + crt_net['HRconv.bias'] = pretrained_net['model.8.bias'] + crt_net['conv_last.weight'] = pretrained_net['model.10.weight'] + crt_net['conv_last.bias'] = pretrained_net['model.10.bias'] + + crt_model.load_state_dict(crt_net) + crt_model.eval() + return crt_model + def upscale_without_tiling(model, img): img = np.array(img) @@ -115,7 +139,7 @@ def esrgan_upscale(model, img): if opts.ESRGAN_tile == 0: return upscale_without_tiling(model, img) - grid = modules.images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap) + grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap) newtiles = [] scale_factor = 1 @@ -130,38 +154,7 @@ def esrgan_upscale(model, img): newrow.append([x * scale_factor, w * scale_factor, output]) newtiles.append([y * scale_factor, h * scale_factor, newrow]) - newgrid = modules.images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor) - output = modules.images.combine_grid(newgrid) + newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, + grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor) + output = images.combine_grid(newgrid) return output - - -class UpscalerESRGAN(modules.images.Upscaler): - def __init__(self, filename, title): - self.name = title - self.filename = filename - - def do_upscale(self, img): - model = load_model(self.filename, self.name) - if model is None: - return img - model.to(shared.device) - img = esrgan_upscale(model, img) - return img - - -def setup_model(dirname): - global model_path - global model_name - if not os.path.exists(model_path): - os.makedirs(model_path) - - model_paths = modelloader.load_models(model_path, command_path=dirname, ext_filter=[".pt", ".pth"]) - if len(model_paths) == 0: - modules.shared.sd_upscalers.append(UpscalerESRGAN(model_url, model_name)) - for file in model_paths: - name = modelloader.friendly_name(file) - try: - modules.shared.sd_upscalers.append(UpscalerESRGAN(file, name)) - except Exception: - print(f"Error loading ESRGAN model: {file}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) diff --git a/modules/extras.py b/modules/extras.py index af6e631f..d7d0fa54 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -66,29 +66,28 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" image = res - if upscaling_resize != 1.0: - def upscale(image, scaler_index, resize): - small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) - pixels = tuple(np.array(small).flatten().tolist()) - key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels + def upscale(image, scaler_index, resize): + small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) + pixels = tuple(np.array(small).flatten().tolist()) + key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels - c = cached_images.get(key) - if c is None: - upscaler = shared.sd_upscalers[scaler_index] - c = upscaler.upscale(image, image.width * resize, image.height * resize) - cached_images[key] = c + c = cached_images.get(key) + if c is None: + upscaler = shared.sd_upscalers[scaler_index] + c = upscaler.scaler.upscale(image, resize, upscaler.data_path) + cached_images[key] = c - return c + return c - info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n" - res = upscale(image, extras_upscaler_1, upscaling_resize) + info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n" + res = upscale(image, extras_upscaler_1, upscaling_resize) - if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: - res2 = upscale(image, extras_upscaler_2, upscaling_resize) - info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n" - res = Image.blend(res, res2, extras_upscaler_2_visibility) + if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: + res2 = upscale(image, extras_upscaler_2, upscaling_resize) + info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n" + res = Image.blend(res, res2, extras_upscaler_2_visibility) - image = res + image = res while len(cached_images) > 2: del cached_images[next(iter(cached_images.keys()))] diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index ffb6960d..2bf8a1ee 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -1,24 +1,23 @@ import os import sys import traceback -from glob import glob -from modules import shared, devices -from modules.shared import cmd_opts -from modules.paths import script_path +import facexlib +import gfpgan + import modules.face_restoration from modules import shared, devices, modelloader from modules.paths import models_path model_dir = "GFPGAN" -cmd_dir = None +user_path = None model_path = os.path.join(models_path, model_dir) model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" - +have_gfpgan = False loaded_gfpgan_model = None -def gfpgan(): +def gfpgann(): global loaded_gfpgan_model global model_path if loaded_gfpgan_model is not None: @@ -28,14 +27,16 @@ def gfpgan(): if gfpgan_constructor is None: return None - models = modelloader.load_models(model_path, model_url, cmd_dir) - if len(models) != 0: + models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN") + if len(models) == 1 and "http" in models[0]: + model_file = models[0] + elif len(models) != 0: latest_file = max(models, key=os.path.getctime) model_file = latest_file else: print("Unable to load gfpgan model!") return None - model = gfpgan_constructor(model_path=model_file, model_dir=model_path, upscale=1, arch='clean', channel_multiplier=2, + model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) model.gfpgan.to(shared.device) loaded_gfpgan_model = model @@ -44,11 +45,12 @@ def gfpgan(): def gfpgan_fix_faces(np_image): - model = gfpgan() + model = gfpgann() if model is None: return np_image np_image_bgr = np_image[:, :, ::-1] - cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) + cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, + only_center_face=False, paste_back=True) np_image = gfpgan_output_bgr[:, :, ::-1] if shared.opts.face_restoration_unload: @@ -57,7 +59,6 @@ def gfpgan_fix_faces(np_image): return np_image -have_gfpgan = False gfpgan_constructor = None @@ -67,14 +68,33 @@ def setup_model(dirname): os.makedirs(model_path) try: - from modules.gfpgan_model_arch import GFPGANerr - global cmd_dir + from gfpgan import GFPGANer + from facexlib import detection, parsing + global user_path global have_gfpgan global gfpgan_constructor - cmd_dir = dirname + load_file_from_url_orig = gfpgan.utils.load_file_from_url + facex_load_file_from_url_orig = facexlib.detection.load_file_from_url + facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url + + def my_load_file_from_url(**kwargs): + print("Setting model_dir to " + model_path) + return load_file_from_url_orig(**dict(kwargs, model_dir=model_path)) + + def facex_load_file_from_url(**kwargs): + return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None)) + + def facex_load_file_from_url2(**kwargs): + return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None)) + + gfpgan.utils.load_file_from_url = my_load_file_from_url + facexlib.detection.load_file_from_url = facex_load_file_from_url + facexlib.parsing.load_file_from_url = facex_load_file_from_url2 + user_path = dirname + print("Have gfpgan should be true?") have_gfpgan = True - gfpgan_constructor = GFPGANerr + gfpgan_constructor = GFPGANer class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): def name(self): @@ -82,7 +102,9 @@ def setup_model(dirname): def restore(self, np_image): np_image_bgr = np_image[:, :, ::-1] - cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) + cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False, + only_center_face=False, + paste_back=True) np_image = gfpgan_output_bgr[:, :, ::-1] return np_image diff --git a/modules/gfpgan_model_arch.py b/modules/gfpgan_model_arch.py deleted file mode 100644 index d81cea96..00000000 --- a/modules/gfpgan_model_arch.py +++ /dev/null @@ -1,150 +0,0 @@ -# GFPGAN likes to download stuff "wherever", and we're trying to fix that, so this is a copy of the original... - -import cv2 -import os -import torch -from basicsr.utils import img2tensor, tensor2img -from basicsr.utils.download_util import load_file_from_url -from facexlib.utils.face_restoration_helper import FaceRestoreHelper -from torchvision.transforms.functional import normalize - -from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear -from gfpgan.archs.gfpganv1_arch import GFPGANv1 -from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean - -ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - - -class GFPGANerr(): - """Helper for restoration with GFPGAN. - - It will detect and crop faces, and then resize the faces to 512x512. - GFPGAN is used to restored the resized faces. - The background is upsampled with the bg_upsampler. - Finally, the faces will be pasted back to the upsample background image. - - Args: - model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically). - upscale (float): The upscale of the final output. Default: 2. - arch (str): The GFPGAN architecture. Option: clean | original. Default: clean. - channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. - bg_upsampler (nn.Module): The upsampler for the background. Default: None. - """ - - def __init__(self, model_path, model_dir, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None): - self.upscale = upscale - self.bg_upsampler = bg_upsampler - - # initialize model - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device - # initialize the GFP-GAN - if arch == 'clean': - self.gfpgan = GFPGANv1Clean( - out_size=512, - num_style_feat=512, - channel_multiplier=channel_multiplier, - decoder_load_path=None, - fix_decoder=False, - num_mlp=8, - input_is_latent=True, - different_w=True, - narrow=1, - sft_half=True) - elif arch == 'bilinear': - self.gfpgan = GFPGANBilinear( - out_size=512, - num_style_feat=512, - channel_multiplier=channel_multiplier, - decoder_load_path=None, - fix_decoder=False, - num_mlp=8, - input_is_latent=True, - different_w=True, - narrow=1, - sft_half=True) - elif arch == 'original': - self.gfpgan = GFPGANv1( - out_size=512, - num_style_feat=512, - channel_multiplier=channel_multiplier, - decoder_load_path=None, - fix_decoder=True, - num_mlp=8, - input_is_latent=True, - different_w=True, - narrow=1, - sft_half=True) - elif arch == 'RestoreFormer': - from gfpgan.archs.restoreformer_arch import RestoreFormer - self.gfpgan = RestoreFormer() - # initialize face helper - self.face_helper = FaceRestoreHelper( - upscale, - face_size=512, - crop_ratio=(1, 1), - det_model='retinaface_resnet50', - save_ext='png', - use_parse=True, - device=self.device, - model_rootpath=model_dir) - - if model_path.startswith('https://'): - model_path = load_file_from_url( - url=model_path, model_dir=model_dir, progress=True, file_name=None) - loadnet = torch.load(model_path) - if 'params_ema' in loadnet: - keyname = 'params_ema' - else: - keyname = 'params' - self.gfpgan.load_state_dict(loadnet[keyname], strict=True) - self.gfpgan.eval() - self.gfpgan = self.gfpgan.to(self.device) - - @torch.no_grad() - def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5): - self.face_helper.clean_all() - - if has_aligned: # the inputs are already aligned - img = cv2.resize(img, (512, 512)) - self.face_helper.cropped_faces = [img] - else: - self.face_helper.read_image(img) - # get face landmarks for each face - self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5) - # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels - # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations. - # align and warp each face - self.face_helper.align_warp_face() - - # face restoration - for cropped_face in self.face_helper.cropped_faces: - # prepare data - cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) - normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device) - - try: - output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0] - # convert to image - restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)) - except RuntimeError as error: - print(f'\tFailed inference for GFPGAN: {error}.') - restored_face = cropped_face - - restored_face = restored_face.astype('uint8') - self.face_helper.add_restored_face(restored_face) - - if not has_aligned and paste_back: - # upsample the background - if self.bg_upsampler is not None: - # Now only support RealESRGAN for upsampling background - bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0] - else: - bg_img = None - - self.face_helper.get_inverse_affine(None) - # paste each restored face to the input image - restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img) - return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img - else: - return self.face_helper.cropped_faces, self.face_helper.restored_faces, None diff --git a/modules/images.py b/modules/images.py index 9458bf8d..a6538dbe 100644 --- a/modules/images.py +++ b/modules/images.py @@ -11,7 +11,6 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin from fonts.ttf import Roboto import string -import modules.shared from modules import sd_samplers, shared from modules.shared import opts, cmd_opts @@ -52,8 +51,8 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64): cols = math.ceil((w - overlap) / non_overlap_width) rows = math.ceil((h - overlap) / non_overlap_height) - dx = (w - tile_w) / (cols-1) if cols > 1 else 0 - dy = (h - tile_h) / (rows-1) if rows > 1 else 0 + dx = (w - tile_w) / (cols - 1) if cols > 1 else 0 + dy = (h - tile_h) / (rows - 1) if rows > 1 else 0 grid = Grid([], tile_w, tile_h, w, h, overlap) for row in range(rows): @@ -67,7 +66,7 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64): for col in range(cols): x = int(col * dx) - if x+tile_w >= w: + if x + tile_w >= w: x = w - tile_w tile = image.crop((x, y, x + tile_w, y + tile_h)) @@ -85,8 +84,10 @@ def combine_grid(grid): r = r.astype(np.uint8) return Image.fromarray(r, 'L') - mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)) - mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1)) + mask_w = make_mask_image( + np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)) + mask_h = make_mask_image( + np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1)) combined_image = Image.new("RGB", (grid.image_w, grid.image_h)) for y, h, row in grid.tiles: @@ -129,10 +130,12 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts): def draw_texts(drawing, draw_x, draw_y, lines): for i, line in enumerate(lines): - drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center") + drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, + fill=color_active if line.is_active else color_inactive, anchor="mm", align="center") if not line.is_active: - drawing.line((draw_x - line.size[0]//2, draw_y + line.size[1]//2, draw_x + line.size[0]//2, draw_y + line.size[1]//2), fill=color_inactive, width=4) + drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, + draw_y + line.size[1] // 2), fill=color_inactive, width=4) draw_y += line.size[1] + line_spacing @@ -171,7 +174,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts): line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1]) hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts] - ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts] + ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in + ver_texts] pad_top = max(hor_text_heights) + line_spacing * 2 @@ -202,8 +206,10 @@ def draw_prompt_matrix(im, width, height, all_prompts): prompts_horiz = prompts[:boundary] prompts_vert = prompts[boundary:] - hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))] - ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))] + hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in + range(1 << len(prompts_horiz))] + ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in + range(1 << len(prompts_vert))] return draw_grid_annotations(im, width, height, hor_texts, ver_texts) @@ -214,7 +220,8 @@ def resize_image(resize_mode, im, width, height): return im.resize((w, h), resample=LANCZOS) upscaler = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img][0] - return upscaler.upscale(im, w, h) + scale = w / im.width + return upscaler.scaler.upscale(im, scale) if resize_mode == 0: res = resize(im, width, height) @@ -244,11 +251,13 @@ def resize_image(resize_mode, im, width, height): if ratio < src_ratio: fill_height = height // 2 - src_h // 2 res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) - res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h)) + res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), + box=(0, fill_height + src_h)) elif ratio > src_ratio: fill_width = width // 2 - src_w // 2 res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) - res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0)) + res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), + box=(fill_width + src_w, 0)) return res @@ -256,7 +265,7 @@ def resize_image(resize_mode, im, width, height): invalid_filename_chars = '<>:"/\\|?*\n' invalid_filename_prefix = ' ' invalid_filename_postfix = ' .' -re_nonletters = re.compile(r'[\s'+string.punctuation+']+') +re_nonletters = re.compile(r'[\s' + string.punctuation + ']+') max_filename_part_length = 128 @@ -283,7 +292,8 @@ def apply_filename_pattern(x, p, seed, prompt): words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0] if len(words) == 0: words = ["empty"] - x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False)) + x = x.replace("[prompt_words]", + sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False)) if p is not None: x = x.replace("[steps]", str(p.steps)) @@ -291,7 +301,8 @@ def apply_filename_pattern(x, p, seed, prompt): x = x.replace("[width]", str(p.width)) x = x.replace("[height]", str(p.height)) x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False)) - x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False)) + x = x.replace("[sampler]", + sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False)) x = x.replace("[model_hash]", shared.sd_model.sd_model_hash) x = x.replace("[date]", datetime.date.today().isoformat()) @@ -303,6 +314,7 @@ def apply_filename_pattern(x, p, seed, prompt): return x + def get_next_sequence_number(path, basename): """ Determines and returns the next sequence number to use when saving an image in the specified directory. @@ -316,7 +328,8 @@ def get_next_sequence_number(path, basename): prefix_length = len(basename) for p in os.listdir(path): if p.startswith(basename): - l = os.path.splitext(p[prefix_length:])[0].split('-') #splits the filename (removing the basename first if one is defined, so the sequence number is always the first element) + l = os.path.splitext(p[prefix_length:])[0].split( + '-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element) try: result = max(int(l[0]), result) except ValueError: @@ -324,7 +337,10 @@ def get_next_sequence_number(path, basename): return result + 1 -def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""): + +def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, + no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, + forced_filename=None, suffix=""): if short_filename or prompt is None or seed is None: file_decoration = "" elif opts.save_to_dirs: @@ -361,7 +377,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i fullfn = "a.png" fullfn_without_extension = "a" for i in range(500): - fn = f"{basecount+i:05}" if basename == '' else f"{basename}-{basecount+i:04}" + fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}" fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}") fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}") if not os.path.exists(fullfn): @@ -403,31 +419,3 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i file.write(info + "\n") -class Upscaler: - name = "Lanczos" - - def do_upscale(self, img): - return img - - def upscale(self, img, w, h): - for i in range(3): - if img.width >= w and img.height >= h: - break - - img = self.do_upscale(img) - - if img.width != w or img.height != h: - img = img.resize((int(w), int(h)), resample=LANCZOS) - - return img - - -class UpscalerNone(Upscaler): - name = "None" - - def upscale(self, img, w, h): - return img - - -modules.shared.sd_upscalers.append(UpscalerNone()) -modules.shared.sd_upscalers.append(Upscaler()) diff --git a/modules/ldsr_model.py b/modules/ldsr_model.py index 4f9b1657..969d1a0d 100644 --- a/modules/ldsr_model.py +++ b/modules/ldsr_model.py @@ -1,74 +1,45 @@ import os import sys import traceback -from collections import namedtuple -from modules import shared, images, modelloader, paths -from modules.paths import models_path - -model_dir = "LDSR" -model_path = os.path.join(models_path, model_dir) -cmd_path = None -model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" -yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" +from basicsr.utils.download_util import load_file_from_url -LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"]) - -ldsr_models = [] -have_ldsr = False -LDSR_obj = None +from modules.upscaler import Upscaler, UpscalerData +from modules.ldsr_model_arch import LDSR +from modules import shared +from modules.paths import models_path -class UpscalerLDSR(images.Upscaler): - def __init__(self, steps): - self.steps = steps +class UpscalerLDSR(Upscaler): + def __init__(self, user_path): self.name = "LDSR" - - def do_upscale(self, img): - return upscale_with_ldsr(img) - - -def setup_model(dirname): - global cmd_path - global model_path - if not os.path.exists(model_path): - os.makedirs(model_path) - cmd_path = dirname - shared.sd_upscalers.append(UpscalerLDSR(100)) - - -def prepare_ldsr(): - path = paths.paths.get("LDSR", None) - if path is None: - return - global have_ldsr - global LDSR_obj - try: - from LDSR import LDSR - model_files = modelloader.load_models(model_path, model_url, cmd_path, dl_name="model.ckpt", ext_filter=[".ckpt"]) - yaml_files = modelloader.load_models(model_path, yaml_url, cmd_path, dl_name="project.yaml", ext_filter=[".yaml"]) - if len(model_files) != 0 and len(yaml_files) != 0: - model_file = model_files[0] - yaml_file = yaml_files[0] - have_ldsr = True - LDSR_obj = LDSR(model_file, yaml_file) - else: - return - - except Exception: - print("Error importing LDSR:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - have_ldsr = False - - -def upscale_with_ldsr(image): - prepare_ldsr() - if not have_ldsr or LDSR_obj is None: - return image - - ddim_steps = shared.opts.ldsr_steps - pre_scale = shared.opts.ldsr_pre_down - post_scale = shared.opts.ldsr_post_down - - image = LDSR_obj.super_resolution(image, ddim_steps, pre_scale, post_scale) - return image + self.model_path = os.path.join(models_path, self.name) + self.user_path = user_path + self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" + self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" + super().__init__() + scaler_data = UpscalerData("LDSR", None, self) + self.scalers = [scaler_data] + + def load_model(self, path: str): + model = load_file_from_url(url=self.model_url, model_dir=self.model_path, + file_name="model.pth", progress=True) + yaml = load_file_from_url(url=self.model_url, model_dir=self.model_path, + file_name="project.yaml", progress=True) + + try: + return LDSR(model, yaml) + + except Exception: + print("Error importing LDSR:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + return None + + def do_upscale(self, img, path): + ldsr = self.load_model(path) + if ldsr is None: + print("NO LDSR!") + return img + ddim_steps = shared.opts.ldsr_steps + pre_scale = shared.opts.ldsr_pre_down + return ldsr.super_resolution(img, ddim_steps, self.scale) diff --git a/modules/ldsr_model_arch.py b/modules/ldsr_model_arch.py new file mode 100644 index 00000000..8fe87c6a --- /dev/null +++ b/modules/ldsr_model_arch.py @@ -0,0 +1,223 @@ +import gc +import time +import warnings + +import numpy as np +import torch +import torchvision +from PIL import Image +from einops import rearrange, repeat +from omegaconf import OmegaConf + +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.util import instantiate_from_config, ismap + +warnings.filterwarnings("ignore", category=UserWarning) + + +# Create LDSR Class +class LDSR: + def load_model_from_config(self, half_attention): + print(f"Loading model from {self.modelPath}") + pl_sd = torch.load(self.modelPath, map_location="cpu") + sd = pl_sd["state_dict"] + config = OmegaConf.load(self.yamlPath) + model = instantiate_from_config(config.model) + model.load_state_dict(sd, strict=False) + model.cuda() + if half_attention: + model = model.half() + + model.eval() + return {"model": model} + + def __init__(self, model_path, yaml_path): + self.modelPath = model_path + self.yamlPath = yaml_path + + @staticmethod + def run(model, selected_path, custom_steps, eta): + example = get_cond(selected_path) + + n_runs = 1 + guider = None + ckwargs = None + ddim_use_x0_pred = False + temperature = 1. + eta = eta + custom_shape = None + + height, width = example["image"].shape[1:3] + split_input = height >= 128 and width >= 128 + + if split_input: + ks = 128 + stride = 64 + vqf = 4 # + model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride), + "vqf": vqf, + "patch_distributed_vq": True, + "tie_braker": False, + "clip_max_weight": 0.5, + "clip_min_weight": 0.01, + "clip_max_tie_weight": 0.5, + "clip_min_tie_weight": 0.01} + else: + if hasattr(model, "split_input_params"): + delattr(model, "split_input_params") + + x_t = None + logs = None + for n in range(n_runs): + if custom_shape is not None: + x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device) + x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0]) + + logs = make_convolutional_sample(example, model, + custom_steps=custom_steps, + eta=eta, quantize_x0=False, + custom_shape=custom_shape, + temperature=temperature, noise_dropout=0., + corrector=guider, corrector_kwargs=ckwargs, x_T=x_t, + ddim_use_x0_pred=ddim_use_x0_pred + ) + return logs + + def super_resolution(self, image, steps=100, target_scale=2, half_attention=False): + model = self.load_model_from_config(half_attention) + + # Run settings + diffusion_steps = int(steps) + eta = 1.0 + + down_sample_method = 'Lanczos' + + gc.collect() + torch.cuda.empty_cache() + + im_og = image + width_og, height_og = im_og.size + # If we can adjust the max upscale size, then the 4 below should be our variable + print("Foo") + down_sample_rate = target_scale / 4 + print(f"Downsample rate is {down_sample_rate}") + width_downsampled_pre = width_og * down_sample_rate + height_downsampled_pre = height_og * down_sample_method + + if down_sample_rate != 1: + print( + f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]') + im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) + else: + print(f"Down sample rate is 1 from {target_scale} / 4") + logs = self.run(model["model"], im_og, diffusion_steps, eta) + + sample = logs["sample"] + sample = sample.detach().cpu() + sample = torch.clamp(sample, -1., 1.) + sample = (sample + 1.) / 2. * 255 + sample = sample.numpy().astype(np.uint8) + sample = np.transpose(sample, (0, 2, 3, 1)) + a = Image.fromarray(sample[0]) + + del model + gc.collect() + torch.cuda.empty_cache() + print(f'Processing finished!') + return a + + +def get_cond(selected_path): + example = dict() + up_f = 4 + c = selected_path.convert('RGB') + c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0) + c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], + antialias=True) + c_up = rearrange(c_up, '1 c h w -> 1 h w c') + c = rearrange(c, '1 c h w -> 1 h w c') + c = 2. * c - 1. + + c = c.to(torch.device("cuda")) + example["LR_image"] = c + example["image"] = c_up + + return example + + +@torch.no_grad() +def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None, + mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None, + corrector_kwargs=None, x_t=None + ): + ddim = DDIMSampler(model) + bs = shape[0] + shape = shape[1:] + print(f"Sampling with eta = {eta}; steps: {steps}") + samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback, + normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta, + mask=mask, x0=x0, temperature=temperature, verbose=False, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, x_t=x_t) + + return samples, intermediates + + +@torch.no_grad() +def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, + corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False): + log = dict() + + z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=not (hasattr(model, 'split_input_params') + and model.cond_stage_key == 'coordinates_bbox'), + return_original_cond=True) + + if custom_shape is not None: + z = torch.randn(custom_shape) + print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}") + + z0 = None + + log["input"] = x + log["reconstruction"] = xrec + + if ismap(xc): + log["original_conditioning"] = model.to_rgb(xc) + if hasattr(model, 'cond_stage_key'): + log[model.cond_stage_key] = model.to_rgb(xc) + + else: + log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x) + if model.cond_stage_model: + log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x) + if model.cond_stage_key == 'class_label': + log[model.cond_stage_key] = xc[model.cond_stage_key] + + with model.ema_scope("Plotting"): + t0 = time.time() + + sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape, + eta=eta, + quantize_x0=quantize_x0, mask=None, x0=z0, + temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs, + x_t=x_T) + t1 = time.time() + + if ddim_use_x0_pred: + sample = intermediates['pred_x0'][-1] + + x_sample = model.decode_first_stage(sample) + + try: + x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) + log["sample_noquant"] = x_sample_noquant + log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) + except: + pass + + log["sample"] = x_sample + log["time"] = t1 - t0 + + return log diff --git a/modules/modelloader.py b/modules/modelloader.py index 3bd1de4d..6de65c69 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,34 +1,36 @@ import os import shutil +import importlib from urllib.parse import urlparse from basicsr.utils.download_util import load_file_from_url +from modules import shared +from modules.upscaler import Upscaler from modules.paths import script_path, models_path -def load_models(model_path: str, model_url: str = None, command_path: str = None, dl_name: str = None, existing=None, - ext_filter=None) -> list: +def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list: """ A one-and done loader to try finding the desired models in specified directories. - @param dl_name: The file name to use for downloading a model. If not specified, it will be used from the URL. - @param model_url: If specified, attempt to download model from the given URL. + @param download_name: Specify to download from model_url immediately. + @param model_url: If no other models are found, this will be downloaded on upscale. @param model_path: The location to store/find models in. @param command_path: A command-line argument to search for models in first. - @param existing: An array of existing model paths. @param ext_filter: An optional list of filename extensions to filter by @return: A list of paths containing the desired model(s) """ + output = [] + if ext_filter is None: ext_filter = [] - if existing is None: - existing = [] try: places = [] if command_path is not None and command_path != model_path: pretrained_path = os.path.join(command_path, 'experiments/pretrained_models') if os.path.exists(pretrained_path): + print(f"Appending path: {pretrained_path}") places.append(pretrained_path) elif os.path.exists(command_path): places.append(command_path) @@ -36,26 +38,24 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None for place in places: if os.path.exists(place): for file in os.listdir(place): - if os.path.isdir(file): + full_path = os.path.join(place, file) + if os.path.isdir(full_path): continue if len(ext_filter) != 0: model_name, extension = os.path.splitext(file) if extension not in ext_filter: continue - if file not in existing: - path = os.path.join(place, file) - existing.append(path) - if model_url is not None and len(existing) == 0: - if dl_name is not None: - model_file = load_file_from_url(url=model_url, model_dir=model_path, file_name=dl_name, progress=True) + if file not in output: + output.append(full_path) + if model_url is not None and len(output) == 0: + if download_name is not None: + dl = load_file_from_url(model_url, model_path, True, download_name) + output.append(dl) else: - model_file = load_file_from_url(url=model_url, model_dir=model_path, progress=True) - - if os.path.exists(model_file) and os.path.isfile(model_file) and model_file not in existing: - existing.append(model_file) + output.append(model_url) except: pass - return existing + return output def friendly_name(file: str): @@ -110,4 +110,38 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None): print(f"Removing empty folder: {src_path}") shutil.rmtree(src_path, True) except: - pass \ No newline at end of file + pass + + +def load_upscalers(): + datas = [] + for cls in Upscaler.__subclasses__(): + name = cls.__name__ + module_name = cls.__module__ + print(f"Class: {name} and {module_name}") + module = importlib.import_module(module_name) + class_ = getattr(module, name) + cmd_name = f"{name.lower().replace('upscaler', '')}-models-path" + print(f"CMD Name: {cmd_name}") + opt_string = None + try: + opt_string = shared.opts.__getattr__(cmd_name) + except: + pass + scaler = class_(opt_string) + for child in scaler.scalers: + print(f"Appending {child.name}") + datas.append(child) + + shared.sd_upscalers = datas + + # for scaler in subclasses: + # print(f"Found scaler: {type(scaler).__name__}") + # try: + # scaler = scaler() + # for child in scaler.scalers: + # print(f"Appending {child.name}") + # datas.append[child] + # except: + # pass + # shared.sd_upscalers = datas diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 458bf678..0a2eb896 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,64 +1,135 @@ import os import sys import traceback -from collections import namedtuple import numpy as np from PIL import Image from basicsr.utils.download_util import load_file_from_url from realesrgan import RealESRGANer -import modules.images +from modules.upscaler import Upscaler, UpscalerData from modules.paths import models_path from modules.shared import cmd_opts, opts -model_dir = "RealESRGAN" -model_path = os.path.join(models_path, model_dir) -cmd_dir = None -RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"]) -realesrgan_models = [] -have_realesrgan = False +class UpscalerRealESRGAN(Upscaler): + def __init__(self, path): + self.name = "RealESRGAN" + self.model_path = os.path.join(models_path, self.name) + self.user_path = path + super().__init__() + try: + from basicsr.archs.rrdbnet_arch import RRDBNet + from realesrgan import RealESRGANer + from realesrgan.archs.srvgg_arch import SRVGGNetCompact + self.enable = True + self.scalers = [] + scalers = self.load_models(path) + for scaler in scalers: + if scaler.name in opts.realesrgan_enabled_models: + self.scalers.append(scaler) + + except Exception: + print("Error importing Real-ESRGAN:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + self.enable = False + self.scalers = [] + + def do_upscale(self, img, path): + if not self.enable: + return img + + info = self.load_model(path) + if not os.path.exists(info.data_path): + print("Unable to load RealESRGAN model: %s" % info.name) + return img + + upsampler = RealESRGANer( + scale=info.scale, + model_path=info.data_path, + model=info.model(), + half=not cmd_opts.no_half, + tile=opts.ESRGAN_tile, + tile_pad=opts.ESRGAN_tile_overlap, + ) + + upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0] + + image = Image.fromarray(upsampled) + return image + + def load_model(self, path): + try: + info = None + for scaler in self.scalers: + if scaler.data_path == path: + info = scaler + + if info is None: + print(f"Unable to find model info: {path}") + return None + + model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True) + info.data_path = model_file + return info + except Exception as e: + print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + return None -def get_realesrgan_models(): + def load_models(self, _): + return get_realesrgan_models(self) + + +def get_realesrgan_models(scaler): try: from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan.archs.srvgg_arch import SRVGGNetCompact models = [ - RealesrganModelInfo( - name="Real-ESRGAN General x4x3", - location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", - netscale=4, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') + UpscalerData( + name="R-ESRGAN General 4xV3", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3" + ".pth", + scale=4, + upscaler=scaler, + model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, + act_type='prelu') ), - RealesrganModelInfo( - name="Real-ESRGAN General WDN x4x3", - location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", - netscale=4, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') + UpscalerData( + name="R-ESRGAN General WDN 4xV3", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", + scale=4, + upscaler=scaler, + model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, + act_type='prelu') ), - RealesrganModelInfo( - name="Real-ESRGAN AnimeVideo", - location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", - netscale=4, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') + UpscalerData( + name="R-ESRGAN AnimeVideo", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", + scale=4, + upscaler=scaler, + model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, + act_type='prelu') ), - RealesrganModelInfo( - name="Real-ESRGAN 4x plus", - location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", - netscale=4, + UpscalerData( + name="R-ESRGAN 4x+", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", + scale=4, + upscaler=scaler, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) ), - RealesrganModelInfo( - name="Real-ESRGAN 4x plus anime 6B", - location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", - netscale=4, + UpscalerData( + name="R-ESRGAN 4x+ Anime6B", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", + scale=4, + upscaler=scaler, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) ), - RealesrganModelInfo( - name="Real-ESRGAN 2x plus", - location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", - netscale=2, + UpscalerData( + name="R-ESRGAN 2x+", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", + scale=2, + upscaler=scaler, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) ), ] @@ -66,69 +137,3 @@ def get_realesrgan_models(): except Exception as e: print("Error making Real-ESRGAN models list:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - - -class UpscalerRealESRGAN(modules.images.Upscaler): - def __init__(self, upscaling, model_index): - self.upscaling = upscaling - self.model_index = model_index - self.name = realesrgan_models[model_index].name - - def do_upscale(self, img): - return upscale_with_realesrgan(img, self.upscaling, self.model_index) - - -def setup_model(dirname): - global model_path - if not os.path.exists(model_path): - os.makedirs(model_path) - - global realesrgan_models - global have_realesrgan - if model_path != dirname: - model_path = dirname - try: - from basicsr.archs.rrdbnet_arch import RRDBNet - from realesrgan import RealESRGANer - from realesrgan.archs.srvgg_arch import SRVGGNetCompact - - realesrgan_models = get_realesrgan_models() - have_realesrgan = True - - for i, model in enumerate(realesrgan_models): - if model.name in opts.realesrgan_enabled_models: - modules.shared.sd_upscalers.append(UpscalerRealESRGAN(model.netscale, i)) - - except Exception: - print("Error importing Real-ESRGAN:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - - realesrgan_models = [RealesrganModelInfo('None', '', 0, None)] - have_realesrgan = False - - -def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index): - if not have_realesrgan: - return image - - info = realesrgan_models[RealESRGAN_model_index] - - model = info.model() - model_file = load_file_from_url(url=info.location, model_dir=model_path, progress=True) - if not os.path.exists(model_file): - print("Unable to load RealESRGAN model: %s" % info.name) - return image - - upsampler = RealESRGANer( - scale=info.netscale, - model_path=info.location, - model=model, - half=not cmd_opts.no_half, - tile=opts.ESRGAN_tile, - tile_pad=opts.ESRGAN_tile_overlap, - ) - - upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0] - - image = Image.fromarray(upsampled) - return image diff --git a/modules/sd_models.py b/modules/sd_models.py index 89b7d276..23826727 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -50,7 +50,7 @@ def setup_model(dirname): if not os.path.exists(model_path): os.makedirs(model_path) checkpoints_list.clear() - model_list = modelloader.load_models(model_path, model_url, dirname, model_name, ext_filter=".ckpt") + model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=dirname, download_name=model_name, ext_filter=".ckpt") cmd_ckpt = shared.cmd_opts.ckpt if os.path.exists(cmd_ckpt): @@ -68,6 +68,7 @@ def setup_model(dirname): def model_hash(filename): try: + print(f"Opening: {filename}") with open(filename, "rb") as file: import hashlib m = hashlib.sha256() diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 666ee1ee..cfc3ee40 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -154,9 +154,9 @@ class VanillaStableDiffusionSampler: # existing code fails with cetin step counts, like 9 try: - samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.ddim_eta) + samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_t=x, eta=p.ddim_eta) except Exception: - samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.ddim_eta) + samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_t=x, eta=p.ddim_eta) return samples_ddim diff --git a/modules/shared.py b/modules/shared.py index c27079eb..4c31039d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -1,18 +1,19 @@ -import sys import argparse +import datetime import json import os +import sys + import gradio as gr import tqdm -import datetime import modules.artists -from modules.paths import script_path, sd_path -from modules.devices import get_optimal_device -import modules.styles import modules.interrogate import modules.memmon import modules.sd_models +import modules.styles +from modules.devices import get_optimal_device +from modules.paths import script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file @@ -38,6 +39,7 @@ parser.add_argument("--share", action='store_true', help="use share=True for gra parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(model_path, 'Codeformer')) parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(model_path, 'GFPGAN')) parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN')) +parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN')) parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN')) parser.add_argument("--stablediffusion-models-path", type=str, help="Path to directory with Stable-diffusion checkpoints.", default=os.path.join(model_path, 'SwinIR')) parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR')) @@ -111,7 +113,7 @@ face_restorers = [] def realesrgan_models_names(): import modules.realesrgan_model - return [x.name for x in modules.realesrgan_model.get_realesrgan_models()] + return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)] class OptionInfo: @@ -176,13 +178,11 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo options_templates.update(options_section(('upscaling', "Upscaling"), { "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), - "realesrgan_enabled_models": OptionInfo(["Real-ESRGAN 4x plus", "Real-ESRGAN 4x plus anime 6B"], "Select which RealESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), + "realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}), "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}), "ldsr_pre_down": OptionInfo(1, "LDSR Pre-process downssample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}), - "ldsr_post_down": OptionInfo(1, "LDSR Post-process down-sample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}), - "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}), })) diff --git a/modules/swinir_model.py b/modules/swinir_model.py index f515779e..ea7b6301 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -1,92 +1,91 @@ import contextlib import os -import sys -import traceback import numpy as np import torch from PIL import Image from basicsr.utils.download_util import load_file_from_url -import modules.images from modules import modelloader from modules.paths import models_path from modules.shared import cmd_opts, opts, device from modules.swinir_model_arch import SwinIR as net +from modules.upscaler import Upscaler, UpscalerData -model_dir = "SwinIR" -model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" -model_name = "SwinIR x4" -model_path = os.path.join(models_path, model_dir) -cmd_path = "" precision_scope = ( torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext ) -def load_model(path, scale=4): - global model_path - global model_name - if "http" in path: - dl_name = "%s%s" % (model_name.replace(" ", "_"), ".pth") - filename = load_file_from_url(url=path, model_dir=model_path, file_name=dl_name, progress=True) - else: - filename = path - if filename is None or not os.path.exists(filename): - return None - model = net( - upscale=scale, - in_chans=3, - img_size=64, - window_size=8, - img_range=1.0, - depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], - embed_dim=240, - num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], - mlp_ratio=2, - upsampler="nearest+conv", - resi_connection="3conv", - ) - - pretrained_model = torch.load(filename) - model.load_state_dict(pretrained_model["params_ema"], strict=True) - if not cmd_opts.no_half: - model = model.half() - return model - - -def setup_model(dirname): - global model_path - global model_name - global cmd_path - if not os.path.exists(model_path): - os.makedirs(model_path) - cmd_path = dirname - model_file = "" - try: - models = modelloader.load_models(model_path, ext_filter=[".pt", ".pth"], command_path=cmd_path) - - if len(models) != 0: - model_file = models[0] - name = modelloader.friendly_name(model_file) - else: - # Add the "default" model if none are found. - model_file = model_url - name = model_name +class UpscalerSwinIR(Upscaler): + def __init__(self, dirname): + self.name = "SwinIR" + self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \ + "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \ + "-L_x4_GAN.pth " + self.model_name = "SwinIR 4x" + self.model_path = os.path.join(models_path, self.name) + self.user_path = dirname + super().__init__() + scalers = [] + model_files = self.find_models(ext_filter=[".pt", ".pth"]) + for model in model_files: + if "http" in model: + name = self.model_name + else: + name = modelloader.friendly_name(model) + model_data = UpscalerData(name, model, self) + scalers.append(model_data) + self.scalers = scalers + + def do_upscale(self, img, model_file): + model = self.load_model(model_file) + if model is None: + return img + model = model.to(device) + img = upscale(img, model) + try: + torch.cuda.empty_cache() + except: + pass + return img - modules.shared.sd_upscalers.append(UpscalerSwin(model_file, name)) - except Exception: - print(f"Error loading SwinIR model: {model_file}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + def load_model(self, path, scale=4): + if "http" in path: + dl_name = "%s%s" % (self.name.replace(" ", "_"), ".pth") + filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True) + else: + filename = path + if filename is None or not os.path.exists(filename): + return None + model = net( + upscale=scale, + in_chans=3, + img_size=64, + window_size=8, + img_range=1.0, + depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], + embed_dim=240, + num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], + mlp_ratio=2, + upsampler="nearest+conv", + resi_connection="3conv", + ) + + pretrained_model = torch.load(filename) + model.load_state_dict(pretrained_model["params_ema"], strict=True) + if not cmd_opts.no_half: + model = model.half() + return model def upscale( - img, - model, - tile=opts.SWIN_tile, - tile_overlap=opts.SWIN_tile_overlap, - window_size=8, - scale=4, + img, + model, + tile=opts.SWIN_tile, + tile_overlap=opts.SWIN_tile_overlap, + window_size=8, + scale=4, ): img = np.array(img) img = img[:, :, ::-1] @@ -125,34 +124,16 @@ def inference(img, model, tile, tile_overlap, window_size, scale): for h_idx in h_idx_list: for w_idx in w_idx_list: - in_patch = img[..., h_idx : h_idx + tile, w_idx : w_idx + tile] + in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] out_patch = model(in_patch) out_patch_mask = torch.ones_like(out_patch) E[ - ..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf ].add_(out_patch) W[ - ..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf ].add_(out_patch_mask) output = E.div_(W) return output - - -class UpscalerSwin(modules.images.Upscaler): - def __init__(self, filename, title): - self.name = title - self.filename = filename - - def do_upscale(self, img): - model = load_model(self.filename) - if model is None: - return img - model = model.to(device) - img = upscale(img, model) - try: - torch.cuda.empty_cache() - except: - pass - return img \ No newline at end of file diff --git a/modules/upscaler.py b/modules/upscaler.py new file mode 100644 index 00000000..d698282f --- /dev/null +++ b/modules/upscaler.py @@ -0,0 +1,121 @@ +import os +from abc import abstractmethod + +import PIL +import numpy as np +import torch +from PIL import Image + +import modules.shared +from modules import modelloader, shared + +LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) +from modules.paths import models_path + + +class Upscaler: + name = None + model_path = None + model_name = None + model_url = None + enable = True + filter = None + model = None + user_path = None + scalers: [] + tile = True + + def __init__(self, create_dirs=False): + self.mod_pad_h = None + self.tile_size = modules.shared.opts.ESRGAN_tile + self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap + self.device = modules.shared.device + self.img = None + self.output = None + self.scale = 1 + self.half = not modules.shared.cmd_opts.no_half + self.pre_pad = 0 + self.mod_scale = None + if self.name is not None and create_dirs: + self.model_path = os.path.join(models_path, self.name) + if not os.path.exists(self.model_path): + os.makedirs(self.model_path) + + try: + import cv2 + self.can_tile = True + except: + pass + + @abstractmethod + def do_upscale(self, img: PIL.Image, selected_model: str): + return img + + def upscale(self, img: PIL.Image, scale: int, selected_model: str = None): + self.scale = scale + dest_w = img.width * scale + dest_h = img.height * scale + for i in range(3): + if img.width >= dest_w and img.height >= dest_h: + break + img = self.do_upscale(img, selected_model) + if img.width != dest_w or img.height != dest_h: + img = img.resize(dest_w, dest_h, resample=LANCZOS) + + return img + + @abstractmethod + def load_model(self, path: str): + pass + + def find_models(self, ext_filter=None) -> list: + return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path) + + def update_status(self, prompt): + print(f"\nextras: {prompt}", file=shared.progress_print_out) + + +class UpscalerData: + name = None + data_path = None + scale: int = 4 + scaler: Upscaler = None + model: None + + def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None): + self.name = name + self.data_path = path + self.scaler = upscaler + self.scale = scale + self.model = model + + +class UpscalerNone(Upscaler): + name = "None" + scalers = [] + + def load_model(self, path): + pass + + def do_upscale(self, img, selected_model=None): + return img + + def __init__(self, dirname=None): + super().__init__(False) + self.scalers = [UpscalerData("None", None, self)] + + +class UpscalerLanczos(Upscaler): + scalers = [] + + def do_upscale(self, img, selected_model=None): + return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS) + + def load_model(self, _): + pass + + def __init__(self, dirname=None): + super().__init__(False) + self.name = "Lanczos" + self.scalers = [UpscalerData("Lanczos", None, self)] + diff --git a/webui.py b/webui.py index 76d392a2..be1bc769 100644 --- a/webui.py +++ b/webui.py @@ -1,9 +1,10 @@ import os import signal import threading - +import modules.paths import modules.codeformer_model as codeformer import modules.esrgan_model as esrgan +import modules.bsrgan_model as bsrgan import modules.extras import modules.face_restoration import modules.gfpgan_model as gfpgan @@ -27,11 +28,7 @@ modules.sd_models.setup_model(cmd_opts.stablediffusion_models_path) codeformer.setup_model(cmd_opts.codeformer_models_path) gfpgan.setup_model(cmd_opts.gfpgan_models_path) shared.face_restorers.append(modules.face_restoration.FaceRestoration()) - -esrgan.setup_model(cmd_opts.esrgan_models_path) -swinir.setup_model(cmd_opts.swinir_models_path) -realesrgan.setup_model(cmd_opts.realesrgan_models_path) -ldsr.setup_model(cmd_opts.ldsr_models_path) +modelloader.load_upscalers() queue_lock = threading.Lock() -- cgit v1.2.3 From d1f098540ad1dbc2abb8d04322634efba650b631 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 30 Sep 2022 11:42:40 +0300 Subject: remove unwanted formatting/functionality from the PR --- launch.py | 7 +-- modules/esrgan_model.py | 123 ++++++++++++++++++++++---------------------- modules/extras.py | 35 +++++++------ modules/gfpgan_model.py | 12 ++--- modules/images.py | 37 +++++-------- modules/ldsr_model_arch.py | 1 - modules/modelloader.py | 9 +++- modules/realesrgan_model.py | 12 ++--- modules/sd_models.py | 56 ++++++-------------- modules/shared.py | 8 +-- webui.py | 2 +- 11 files changed, 127 insertions(+), 175 deletions(-) (limited to 'modules/shared.py') diff --git a/launch.py b/launch.py index 3b8d8f23..d2793ed2 100644 --- a/launch.py +++ b/launch.py @@ -1,5 +1,4 @@ # this scripts installs necessary requirements and launches main program in webui.py -import shutil import subprocess import os import sys @@ -119,11 +118,7 @@ git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash) -if os.path.isdir(repo_dir('latent-diffusion')): - try: - shutil.rmtree(repo_dir('latent-diffusion')) - except: - pass + if not is_installed("lpips"): run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer") diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index ce841aa4..ea91abfe 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -13,6 +13,63 @@ from modules.upscaler import Upscaler, UpscalerData from modules.shared import opts +def fix_model_layers(crt_model, pretrained_net): + # this code is adapted from https://github.com/xinntao/ESRGAN + if 'conv_first.weight' in pretrained_net: + return pretrained_net + + if 'model.0.weight' not in pretrained_net: + is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"] + if is_realesrgan: + raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.") + else: + raise Exception("The file is not a ESRGAN model.") + + crt_net = crt_model.state_dict() + load_net_clean = {} + for k, v in pretrained_net.items(): + if k.startswith('module.'): + load_net_clean[k[7:]] = v + else: + load_net_clean[k] = v + pretrained_net = load_net_clean + + tbd = [] + for k, v in crt_net.items(): + tbd.append(k) + + # directly copy + for k, v in crt_net.items(): + if k in pretrained_net and pretrained_net[k].size() == v.size(): + crt_net[k] = pretrained_net[k] + tbd.remove(k) + + crt_net['conv_first.weight'] = pretrained_net['model.0.weight'] + crt_net['conv_first.bias'] = pretrained_net['model.0.bias'] + + for k in tbd.copy(): + if 'RDB' in k: + ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') + if '.weight' in k: + ori_k = ori_k.replace('.weight', '.0.weight') + elif '.bias' in k: + ori_k = ori_k.replace('.bias', '.0.bias') + crt_net[k] = pretrained_net[ori_k] + tbd.remove(k) + + crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight'] + crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias'] + crt_net['upconv1.weight'] = pretrained_net['model.3.weight'] + crt_net['upconv1.bias'] = pretrained_net['model.3.bias'] + crt_net['upconv2.weight'] = pretrained_net['model.6.weight'] + crt_net['upconv2.bias'] = pretrained_net['model.6.bias'] + crt_net['HRconv.weight'] = pretrained_net['model.8.weight'] + crt_net['HRconv.bias'] = pretrained_net['model.8.bias'] + crt_net['conv_last.weight'] = pretrained_net['model.10.weight'] + crt_net['conv_last.bias'] = pretrained_net['model.10.bias'] + + return crt_net + class UpscalerESRGAN(Upscaler): def __init__(self, dirname): self.name = "ESRGAN" @@ -28,14 +85,12 @@ class UpscalerESRGAN(Upscaler): scaler_data = UpscalerData(self.model_name, self.model_url, self, 4) scalers.append(scaler_data) for file in model_paths: - print(f"File: {file}") if "http" in file: name = self.model_name else: name = modelloader.friendly_name(file) scaler_data = UpscalerData(name, file, self, 4) - print(f"ESRGAN: Adding scaler {name}") self.scalers.append(scaler_data) def do_upscale(self, img, selected_model): @@ -56,67 +111,14 @@ class UpscalerESRGAN(Upscaler): if not os.path.exists(filename) or filename is None: print("Unable to load %s from %s" % (self.model_path, filename)) return None - # this code is adapted from https://github.com/xinntao/ESRGAN + pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None) crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) - if 'conv_first.weight' in pretrained_net: - crt_model.load_state_dict(pretrained_net) - return crt_model - - if 'model.0.weight' not in pretrained_net: - is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net[ - "params_ema"] - if is_realesrgan: - raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.") - else: - raise Exception("The file is not a ESRGAN model.") - - crt_net = crt_model.state_dict() - load_net_clean = {} - for k, v in pretrained_net.items(): - if k.startswith('module.'): - load_net_clean[k[7:]] = v - else: - load_net_clean[k] = v - pretrained_net = load_net_clean - - tbd = [] - for k, v in crt_net.items(): - tbd.append(k) - - # directly copy - for k, v in crt_net.items(): - if k in pretrained_net and pretrained_net[k].size() == v.size(): - crt_net[k] = pretrained_net[k] - tbd.remove(k) - - crt_net['conv_first.weight'] = pretrained_net['model.0.weight'] - crt_net['conv_first.bias'] = pretrained_net['model.0.bias'] - - for k in tbd.copy(): - if 'RDB' in k: - ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') - if '.weight' in k: - ori_k = ori_k.replace('.weight', '.0.weight') - elif '.bias' in k: - ori_k = ori_k.replace('.bias', '.0.bias') - crt_net[k] = pretrained_net[ori_k] - tbd.remove(k) - - crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight'] - crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias'] - crt_net['upconv1.weight'] = pretrained_net['model.3.weight'] - crt_net['upconv1.bias'] = pretrained_net['model.3.bias'] - crt_net['upconv2.weight'] = pretrained_net['model.6.weight'] - crt_net['upconv2.bias'] = pretrained_net['model.6.bias'] - crt_net['HRconv.weight'] = pretrained_net['model.8.weight'] - crt_net['HRconv.bias'] = pretrained_net['model.8.bias'] - crt_net['conv_last.weight'] = pretrained_net['model.10.weight'] - crt_net['conv_last.bias'] = pretrained_net['model.10.bias'] - - crt_model.load_state_dict(crt_net) + pretrained_net = fix_model_layers(crt_model, pretrained_net) + crt_model.load_state_dict(pretrained_net) crt_model.eval() + return crt_model @@ -154,7 +156,6 @@ def esrgan_upscale(model, img): newrow.append([x * scale_factor, w * scale_factor, output]) newtiles.append([y * scale_factor, h * scale_factor, newrow]) - newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, - grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor) + newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor) output = images.combine_grid(newgrid) return output diff --git a/modules/extras.py b/modules/extras.py index 1d4e9fa8..1bff5874 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -67,28 +67,29 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" image = res - def upscale(image, scaler_index, resize): - small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) - pixels = tuple(np.array(small).flatten().tolist()) - key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels + if upscaling_resize != 1.0: + def upscale(image, scaler_index, resize): + small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) + pixels = tuple(np.array(small).flatten().tolist()) + key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels - c = cached_images.get(key) - if c is None: - upscaler = shared.sd_upscalers[scaler_index] - c = upscaler.scaler.upscale(image, resize, upscaler.data_path) - cached_images[key] = c + c = cached_images.get(key) + if c is None: + upscaler = shared.sd_upscalers[scaler_index] + c = upscaler.scaler.upscale(image, resize, upscaler.data_path) + cached_images[key] = c - return c + return c - info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n" - res = upscale(image, extras_upscaler_1, upscaling_resize) + info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n" + res = upscale(image, extras_upscaler_1, upscaling_resize) - if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: - res2 = upscale(image, extras_upscaler_2, upscaling_resize) - info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n" - res = Image.blend(res, res2, extras_upscaler_2_visibility) + if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0: + res2 = upscale(image, extras_upscaler_2, upscaling_resize) + info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n" + res = Image.blend(res, res2, extras_upscaler_2_visibility) - image = res + image = res while len(cached_images) > 2: del cached_images[next(iter(cached_images.keys()))] diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 2bf8a1ee..bb30d733 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -36,8 +36,7 @@ def gfpgann(): else: print("Unable to load gfpgan model!") return None - model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, - bg_upsampler=None) + model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) model.gfpgan.to(shared.device) loaded_gfpgan_model = model @@ -49,8 +48,7 @@ def gfpgan_fix_faces(np_image): if model is None: return np_image np_image_bgr = np_image[:, :, ::-1] - cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, - only_center_face=False, paste_back=True) + cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) np_image = gfpgan_output_bgr[:, :, ::-1] if shared.opts.face_restoration_unload: @@ -79,7 +77,6 @@ def setup_model(dirname): facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url def my_load_file_from_url(**kwargs): - print("Setting model_dir to " + model_path) return load_file_from_url_orig(**dict(kwargs, model_dir=model_path)) def facex_load_file_from_url(**kwargs): @@ -92,7 +89,6 @@ def setup_model(dirname): facexlib.detection.load_file_from_url = facex_load_file_from_url facexlib.parsing.load_file_from_url = facex_load_file_from_url2 user_path = dirname - print("Have gfpgan should be true?") have_gfpgan = True gfpgan_constructor = GFPGANer @@ -102,9 +98,7 @@ def setup_model(dirname): def restore(self, np_image): np_image_bgr = np_image[:, :, ::-1] - cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False, - only_center_face=False, - paste_back=True) + cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) np_image = gfpgan_output_bgr[:, :, ::-1] return np_image diff --git a/modules/images.py b/modules/images.py index 6430cfec..e89c44b2 100644 --- a/modules/images.py +++ b/modules/images.py @@ -84,10 +84,8 @@ def combine_grid(grid): r = r.astype(np.uint8) return Image.fromarray(r, 'L') - mask_w = make_mask_image( - np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)) - mask_h = make_mask_image( - np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1)) + mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)) + mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1)) combined_image = Image.new("RGB", (grid.image_w, grid.image_h)) for y, h, row in grid.tiles: @@ -130,12 +128,10 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts): def draw_texts(drawing, draw_x, draw_y, lines): for i, line in enumerate(lines): - drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, - fill=color_active if line.is_active else color_inactive, anchor="mm", align="center") + drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center") if not line.is_active: - drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, - draw_y + line.size[1] // 2), fill=color_inactive, width=4) + drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4) draw_y += line.size[1] + line_spacing @@ -206,10 +202,8 @@ def draw_prompt_matrix(im, width, height, all_prompts): prompts_horiz = prompts[:boundary] prompts_vert = prompts[boundary:] - hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in - range(1 << len(prompts_horiz))] - ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in - range(1 << len(prompts_vert))] + hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))] + ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))] return draw_grid_annotations(im, width, height, hor_texts, ver_texts) @@ -259,13 +253,11 @@ def resize_image(resize_mode, im, width, height): if ratio < src_ratio: fill_height = height // 2 - src_h // 2 res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) - res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), - box=(0, fill_height + src_h)) + res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h)) elif ratio > src_ratio: fill_width = width // 2 - src_w // 2 res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) - res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), - box=(fill_width + src_w, 0)) + res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0)) return res @@ -300,8 +292,7 @@ def apply_filename_pattern(x, p, seed, prompt): words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0] if len(words) == 0: words = ["empty"] - x = x.replace("[prompt_words]", - sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False)) + x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False)) if p is not None: x = x.replace("[steps]", str(p.steps)) @@ -309,8 +300,7 @@ def apply_filename_pattern(x, p, seed, prompt): x = x.replace("[width]", str(p.width)) x = x.replace("[height]", str(p.height)) x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False)) - x = x.replace("[sampler]", - sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False)) + x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False)) x = x.replace("[model_hash]", shared.sd_model.sd_model_hash) x = x.replace("[date]", datetime.date.today().isoformat()) @@ -336,8 +326,7 @@ def get_next_sequence_number(path, basename): prefix_length = len(basename) for p in os.listdir(path): if p.startswith(basename): - l = os.path.splitext(p[prefix_length:])[0].split( - '-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element) + l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element) try: result = max(int(l[0]), result) except ValueError: @@ -346,9 +335,7 @@ def get_next_sequence_number(path, basename): return result + 1 -def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, - no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, - forced_filename=None, suffix=""): +def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""): if short_filename or prompt is None or seed is None: file_decoration = "" elif opts.save_to_dirs: diff --git a/modules/ldsr_model_arch.py b/modules/ldsr_model_arch.py index f8f3c3d3..7faac6e1 100644 --- a/modules/ldsr_model_arch.py +++ b/modules/ldsr_model_arch.py @@ -125,7 +125,6 @@ class LDSR: del model gc.collect() torch.cuda.empty_cache() - print(f'Processing finished!') return a diff --git a/modules/modelloader.py b/modules/modelloader.py index b3e6dc36..1106aeb7 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -25,8 +25,10 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None if ext_filter is None: ext_filter = [] + try: places = [] + if command_path is not None and command_path != model_path: pretrained_path = os.path.join(command_path, 'experiments/pretrained_models') if os.path.exists(pretrained_path): @@ -34,7 +36,9 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None places.append(pretrained_path) elif os.path.exists(command_path): places.append(command_path) + places.append(model_path) + for place in places: if os.path.exists(place): for file in os.listdir(place): @@ -47,14 +51,17 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None continue if file not in output: output.append(full_path) + if model_url is not None and len(output) == 0: if download_name is not None: dl = load_file_from_url(model_url, model_path, True, download_name) output.append(dl) else: output.append(model_url) - except: + + except Exception: pass + return output diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 0a2eb896..dc0123e0 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -88,28 +88,24 @@ def get_realesrgan_models(scaler): models = [ UpscalerData( name="R-ESRGAN General 4xV3", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3" - ".pth", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", scale=4, upscaler=scaler, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, - act_type='prelu') + model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') ), UpscalerData( name="R-ESRGAN General WDN 4xV3", path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", scale=4, upscaler=scaler, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, - act_type='prelu') + model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') ), UpscalerData( name="R-ESRGAN AnimeVideo", path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", scale=4, upscaler=scaler, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, - act_type='prelu') + model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') ), UpscalerData( name="R-ESRGAN 4x+", diff --git a/modules/sd_models.py b/modules/sd_models.py index 4b9000a4..caa85d5e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -12,10 +12,10 @@ from modules import shared, modelloader from modules.paths import models_path model_dir = "Stable-diffusion" -model_path = os.path.join(models_path, model_dir) +model_path = os.path.abspath(os.path.join(models_path, model_dir)) model_name = "sd-v1-4.ckpt" model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1" -user_dir = None +user_dir: (str | None) = None CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) checkpoints_list = {} @@ -30,26 +30,8 @@ except Exception: pass -def modeltitle(path, h): - abspath = os.path.abspath(path) - - if abspath.startswith(model_dir): - name = abspath.replace(model_dir, '') - else: - name = os.path.basename(path) - - if name.startswith("\\") or name.startswith("/"): - name = name[1:] - - return f'{name} [{h}]' - - def setup_model(dirname): - global model_path - global model_name - global model_url global user_dir - global model_list user_dir = dirname if not os.path.exists(model_path): os.makedirs(model_path) @@ -62,21 +44,16 @@ def checkpoint_tiles(): def list_models(): - global model_path - global model_url - global model_name - global user_dir checkpoints_list.clear() - model_list = modelloader.load_models(model_path=model_path,model_url=model_url,command_path= user_dir, - ext_filter=[".ckpt"], download_name=model_name) - print(f"Model list: {model_list}") - model_dir = os.path.abspath(model_path) + model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name) - def modeltitle(path, h): + def modeltitle(path, shorthash): abspath = os.path.abspath(path) - if abspath.startswith(model_dir): - name = abspath.replace(model_dir, '') + if user_dir is not None and abspath.startswith(user_dir): + name = abspath.replace(user_dir, '') + elif abspath.startswith(model_path): + name = abspath.replace(model_path, '') else: name = os.path.basename(path) @@ -85,29 +62,30 @@ def list_models(): shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] - return f'{name} [{h}]', shortname + return f'{name} [{shorthash}]', shortname cmd_ckpt = shared.cmd_opts.ckpt if os.path.exists(cmd_ckpt): h = model_hash(cmd_ckpt) - title, model_name = modeltitle(cmd_ckpt, h) - checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name) + title, short_model_name = modeltitle(cmd_ckpt, h) + checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) for filename in model_list: h = model_hash(filename) - title, model_name = modeltitle(filename, h) - checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name) + title, short_model_name = modeltitle(filename, h) + checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name) + def get_closet_checkpoint_match(searchString): applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title)) - if len(applicable)>0: + if len(applicable) > 0: return applicable[0] return None + def model_hash(filename): try: - print(f"Opening: {filename}") with open(filename, "rb") as file: import hashlib m = hashlib.sha256() @@ -128,7 +106,7 @@ def select_checkpoint(): if len(checkpoints_list) == 0: print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr) print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr) - print(f" - directory {os.path.abspath(shared.cmd_opts.stablediffusion_models_path)}", file=sys.stderr) + print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr) print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr) exit(1) diff --git a/modules/shared.py b/modules/shared.py index 69002158..03a1a4d3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -21,8 +21,7 @@ model_path = os.path.join(script_path, 'models') parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",) -# This should be deprecated, but we'll leave it for a few iterations -parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints (Deprecated, use '--stablediffusion-models-path'", ) +parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") @@ -41,7 +40,6 @@ parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory wi parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN')) parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN')) parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN')) -parser.add_argument("--stablediffusion-models-path", type=str, help="Path to directory with Stable-diffusion checkpoints.", default=os.path.join(model_path, 'SwinIR')) parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR')) parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") @@ -61,10 +59,6 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) cmd_opts = parser.parse_args() -if cmd_opts.ckpt_dir is not None: - print("The 'ckpt-dir' arg is deprecated in favor of the 'stablediffusion-models-path' argument and will be " - "removed in a future release. Please use the new option if you wish to use a custom checkpoint directory.") - cmd_opts.__setattr__("stablediffusion-models-path", cmd_opts.ckpt_dir) device = get_optimal_device() batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) diff --git a/webui.py b/webui.py index 5fd65edc..b8cccd54 100644 --- a/webui.py +++ b/webui.py @@ -28,7 +28,7 @@ from modules.paths import script_path from modules.shared import cmd_opts modelloader.cleanup_models() -modules.sd_models.setup_model(cmd_opts.stablediffusion_models_path) +modules.sd_models.setup_model(cmd_opts.ckpt_dir) codeformer.setup_model(cmd_opts.codeformer_models_path) gfpgan.setup_model(cmd_opts.gfpgan_models_path) shared.face_restorers.append(modules.face_restoration.FaceRestoration()) -- cgit v1.2.3 From b60cd0809f6dc99ff0751593f562b64609b3117c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 30 Sep 2022 12:56:36 +0300 Subject: return the dropdown that mysteriously disappeared --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 03a1a4d3..8428c7a3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -193,7 +193,7 @@ options_templates.update(options_section(('system', "System"), { })) options_templates.update(options_section(('sd', "Stable Diffusion"), { - "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), + "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), -- cgit v1.2.3 From 2b03f0bbda1229dff6e7ab6f656b28587eba8308 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 30 Sep 2022 22:16:03 +0300 Subject: if --ckpt option is specified, load that model --- modules/sd_models.py | 1 + modules/shared.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ab014efb..2539f14c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -69,6 +69,7 @@ def list_models(): h = model_hash(cmd_ckpt) title, short_model_name = modeltitle(cmd_ckpt, h) checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) + shared.opts.sd_model_checkpoint = title elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) for filename in model_list: diff --git a/modules/shared.py b/modules/shared.py index 8428c7a3..ac968b2d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -20,7 +20,7 @@ default_sd_model_file = sd_model_file model_path = os.path.join(script_path, 'models') parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",) -parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",) +parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints") parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) -- cgit v1.2.3 From 820f1dc96b1979d7e92170c161db281ee8bd988b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 15:03:39 +0300 Subject: initial support for training textual inversion --- .gitignore | 1 + javascript/progressbar.js | 1 + javascript/textualInversion.js | 8 + modules/devices.py | 3 +- modules/processing.py | 13 +- modules/sd_hijack.py | 324 ++++------------------ modules/sd_hijack_optimizations.py | 164 +++++++++++ modules/sd_models.py | 4 +- modules/shared.py | 3 +- modules/textual_inversion/dataset.py | 76 +++++ modules/textual_inversion/textual_inversion.py | 258 +++++++++++++++++ modules/textual_inversion/ui.py | 32 +++ modules/ui.py | 139 ++++++++-- style.css | 10 +- textual_inversion_templates/style.txt | 19 ++ textual_inversion_templates/style_filewords.txt | 19 ++ textual_inversion_templates/subject.txt | 27 ++ textual_inversion_templates/subject_filewords.txt | 27 ++ webui.py | 15 +- 19 files changed, 828 insertions(+), 315 deletions(-) create mode 100644 javascript/textualInversion.js create mode 100644 modules/sd_hijack_optimizations.py create mode 100644 modules/textual_inversion/dataset.py create mode 100644 modules/textual_inversion/textual_inversion.py create mode 100644 modules/textual_inversion/ui.py create mode 100644 textual_inversion_templates/style.txt create mode 100644 textual_inversion_templates/style_filewords.txt create mode 100644 textual_inversion_templates/subject.txt create mode 100644 textual_inversion_templates/subject_filewords.txt (limited to 'modules/shared.py') diff --git a/.gitignore b/.gitignore index 3532dab3..7afc9395 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,4 @@ __pycache__ /.idea notification.mp3 /SwinIR +/textual_inversion diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 21f25b38..1e297abb 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -30,6 +30,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte onUiUpdate(function(){ check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') + check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery') }) function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){ diff --git a/javascript/textualInversion.js b/javascript/textualInversion.js new file mode 100644 index 00000000..8061be08 --- /dev/null +++ b/javascript/textualInversion.js @@ -0,0 +1,8 @@ + + +function start_training_textual_inversion(){ + requestProgress('ti') + gradioApp().querySelector('#ti_error').innerHTML='' + + return args_to_array(arguments) +} diff --git a/modules/devices.py b/modules/devices.py index 07bb2339..ff82f2f6 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -32,10 +32,9 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") - device = get_optimal_device() device_codeformer = cpu if has_mps else device - +dtype = torch.float16 def randn(seed, shape): # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. diff --git a/modules/processing.py b/modules/processing.py index 7eeb5191..8223423a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -56,7 +56,7 @@ class StableDiffusionProcessing: self.prompt: str = prompt self.prompt_for_display: str = None self.negative_prompt: str = (negative_prompt or "") - self.styles: str = styles + self.styles: list = styles or [] self.seed: int = seed self.subseed: int = subseed self.subseed_strength: float = subseed_strength @@ -271,7 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), - "Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta), + "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), } generation_params.update(p.extra_generation_params) @@ -295,8 +295,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed: fix_seed(p) - os.makedirs(p.outpath_samples, exist_ok=True) - os.makedirs(p.outpath_grids, exist_ok=True) + if p.outpath_samples is not None: + os.makedirs(p.outpath_samples, exist_ok=True) + + if p.outpath_grids is not None: + os.makedirs(p.outpath_grids, exist_ok=True) modules.sd_hijack.model_hijack.apply_circular(p.tiling) @@ -323,7 +326,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch) if os.path.exists(cmd_opts.embeddings_dir): - model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model) + model_hijack.embedding_db.load_textual_inversion_embeddings() infotexts = [] output_images = [] diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index fa7eaeb8..fd57e5c5 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -6,244 +6,41 @@ import torch import numpy as np from torch import einsum -from modules import prompt_parser +import modules.textual_inversion.textual_inversion +from modules import prompt_parser, devices, sd_hijack_optimizations, shared from modules.shared import opts, device, cmd_opts -from ldm.util import default -from einops import rearrange import ldm.modules.attention import ldm.modules.diffusionmodules.model +attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward +diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity +diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward -# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion -def split_cross_attention_forward_v1(self, x, context=None, mask=None): - h = self.heads - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - del context, x +def apply_optimizations(): + if cmd_opts.opt_split_attention_v1: + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 + elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward + ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - for i in range(0, q.shape[0], 2): - end = i + 2 - s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) - s1 *= self.scale +def undo_optimizations(): + ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward + ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity + ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward - s2 = s1.softmax(dim=-1) - del s1 - - r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) - del s2 - - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - - return self.to_out(r2) - - -# taken from https://github.com/Doggettx/stable-diffusion -def split_cross_attention_forward(self, x, context=None, mask=None): - h = self.heads - - q_in = self.to_q(x) - context = default(context, x) - k_in = self.to_k(context) * self.scale - v_in = self.to_v(context) - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - if mem_required > mem_free_total: - steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) - # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") - - if steps > 64: - max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 - raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' - f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) - - s2 = s1.softmax(dim=-1, dtype=q.dtype) - del s1 - - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 - - del q, k, v - - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - - return self.to_out(r2) - -def nonlinearity_hijack(x): - # swish - t = torch.sigmoid(x) - x *= t - del t - - return x - -def cross_attention_attnblock_forward(self, x): - h_ = x - h_ = self.norm(h_) - q1 = self.q(h_) - k1 = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q1.shape - - q2 = q1.reshape(b, c, h*w) - del q1 - - q = q2.permute(0, 2, 1) # b,hw,c - del q2 - - k = k1.reshape(b, c, h*w) # b,c,hw - del k1 - - h_ = torch.zeros_like(k, device=q.device) - - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - - tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() - mem_required = tensor_size * 2.5 - steps = 1 - - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - - w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w2 = w1 * (int(c)**(-0.5)) - del w1 - w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) - del w2 - - # attend to values - v1 = v.reshape(b, c, h*w) - w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - del w3 - - h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - del v1, w4 - - h2 = h_.reshape(b, c, h, w) - del h_ - - h3 = self.proj_out(h2) - del h2 - - h3 += x - - return h3 class StableDiffusionModelHijack: - ids_lookup = {} - word_embeddings = {} - word_embeddings_checksums = {} fixes = None comments = [] - dir_mtime = None layers = None circular_enabled = False clip = None - def load_textual_inversion_embeddings(self, dirname, model): - mt = os.path.getmtime(dirname) - if self.dir_mtime is not None and mt <= self.dir_mtime: - return - - self.dir_mtime = mt - self.ids_lookup.clear() - self.word_embeddings.clear() - - tokenizer = model.cond_stage_model.tokenizer - - def const_hash(a): - r = 0 - for v in a: - r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF - return r - - def process_file(path, filename): - name = os.path.splitext(filename)[0] - - data = torch.load(path, map_location="cpu") - - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - if hasattr(param_dict, '_parameters'): - param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - # diffuser concepts - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - - self.word_embeddings[name] = emb.detach().to(device) - self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}' - - ids = tokenizer([name], add_special_tokens=False)['input_ids'][0] - - first_id = ids[0] - if first_id not in self.ids_lookup: - self.ids_lookup[first_id] = [] - self.ids_lookup[first_id].append((ids, name)) - - for fn in os.listdir(dirname): - try: - fullfn = os.path.join(dirname, fn) - - if os.stat(fullfn).st_size == 0: - continue - - process_file(fullfn, fn) - except Exception: - print(f"Error loading emedding {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - continue - - print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") + embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) def hijack(self, m): model_embeddings = m.cond_stage_model.transformer.text_model.embeddings @@ -253,12 +50,7 @@ class StableDiffusionModelHijack: self.clip = m.cond_stage_model - if cmd_opts.opt_split_attention_v1: - ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): - ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward - ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack - ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward + apply_optimizations() def flatten(el): flattened = [flatten(children) for children in el.children()] @@ -296,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): super().__init__() self.wrapped = wrapped - self.hijack = hijack + self.hijack: StableDiffusionModelHijack = hijack self.tokenizer = wrapped.tokenizer self.max_length = wrapped.max_length self.token_mults = {} @@ -317,7 +109,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult - def tokenize_line(self, line, used_custom_terms, hijack_comments): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id @@ -339,28 +130,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - possible_matches = self.hijack.ids_lookup.get(token, None) + embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) - if possible_matches is None: + if embedding is None: remade_tokens.append(token) multipliers.append(weight) + i += 1 else: - found = False - for ids, word in possible_matches: - if tokens[i:i + len(ids)] == ids: - emb_len = int(self.hijack.word_embeddings[word].shape[0]) - fixes.append((len(remade_tokens), word)) - remade_tokens += [0] * emb_len - multipliers += [weight] * emb_len - i += len(ids) - 1 - found = True - used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word])) - break - - if not found: - remade_tokens.append(token) - multipliers.append(weight) - i += 1 + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [weight] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += emb_len if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} @@ -431,32 +213,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - possible_matches = self.hijack.ids_lookup.get(token, None) + embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) mult_change = self.token_mults.get(token) if opts.enable_emphasis else None if mult_change is not None: mult *= mult_change - elif possible_matches is None: + i += 1 + elif embedding is None: remade_tokens.append(token) multipliers.append(mult) + i += 1 else: - found = False - for ids, word in possible_matches: - if tokens[i:i+len(ids)] == ids: - emb_len = int(self.hijack.word_embeddings[word].shape[0]) - fixes.append((len(remade_tokens), word)) - remade_tokens += [0] * emb_len - multipliers += [mult] * emb_len - i += len(ids) - 1 - found = True - used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word])) - break - - if not found: - remade_tokens.append(token) - multipliers.append(mult) - - i += 1 + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [mult] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += emb_len if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} @@ -464,6 +237,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): overflowing_words = [vocab.get(int(x), "") for x in ovf] overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + token_count = len(remade_tokens) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] @@ -484,7 +258,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): else: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - self.hijack.fixes = hijack_fixes self.hijack.comments = hijack_comments @@ -517,14 +290,19 @@ class EmbeddingsWithFixes(torch.nn.Module): inputs_embeds = self.wrapped(input_ids) - if batch_fixes is not None: - for fixes, tensor in zip(batch_fixes, inputs_embeds): - for offset, word in fixes: - emb = self.embeddings.word_embeddings[word] - emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) - tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len] + if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0: + return inputs_embeds + + vecs = [] + for fixes, tensor in zip(batch_fixes, inputs_embeds): + for offset, embedding in fixes: + emb = embedding.vec + emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) + tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]]) + + vecs.append(tensor) - return inputs_embeds + return torch.stack(vecs) def add_circular_option_to_conv_2d(): diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py new file mode 100644 index 00000000..9c079e57 --- /dev/null +++ b/modules/sd_hijack_optimizations.py @@ -0,0 +1,164 @@ +import math +import torch +from torch import einsum + +from ldm.util import default +from einops import rearrange + + +# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion +def split_cross_attention_forward_v1(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) + for i in range(0, q.shape[0], 2): + end = i + 2 + s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) + s1 *= self.scale + + s2 = s1.softmax(dim=-1) + del s1 + + r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) + del s2 + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) + + +# taken from https://github.com/Doggettx/stable-diffusion +def split_cross_attention_forward(self, x, context=None, mask=None): + h = self.heads + + q_in = self.to_q(x) + context = default(context, x) + k_in = self.to_k(context) * self.scale + v_in = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + if mem_required > mem_free_total: + steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' + f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) + + s2 = s1.softmax(dim=-1, dtype=q.dtype) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) + +def nonlinearity_hijack(x): + # swish + t = torch.sigmoid(x) + x *= t + del t + + return x + +def cross_attention_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q1 = self.q(h_) + k1 = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q1.shape + + q2 = q1.reshape(b, c, h*w) + del q1 + + q = q2.permute(0, 2, 1) # b,hw,c + del q2 + + k = k1.reshape(b, c, h*w) # b,c,hw + del k1 + + h_ = torch.zeros_like(k, device=q.device) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + + w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w2 = w1 * (int(c)**(-0.5)) + del w1 + w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) + del w2 + + # attend to values + v1 = v.reshape(b, c, h*w) + w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w3 + + h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w4 + + h2 = h_.reshape(b, c, h, w) + del h_ + + h3 = self.proj_out(h2) + del h2 + + h3 += x + + return h3 diff --git a/modules/sd_models.py b/modules/sd_models.py index 2539f14c..5b3dbdc7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -8,7 +8,7 @@ from omegaconf import OmegaConf from ldm.util import instantiate_from_config -from modules import shared, modelloader +from modules import shared, modelloader, devices from modules.paths import models_path model_dir = "Stable-diffusion" @@ -134,6 +134,8 @@ def load_model_weights(model, checkpoint_file, sd_model_hash): if not shared.cmd_opts.no_half: model.half() + devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + model.sd_model_hash = sd_model_hash model.sd_model_checkpint = checkpoint_file diff --git a/modules/shared.py b/modules/shared.py index ac968b2d..ac0bc480 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -78,6 +78,7 @@ class State: current_latent = None current_image = None current_image_sampling_step = 0 + textinfo = None def interrupt(self): self.interrupted = True @@ -88,7 +89,7 @@ class State: self.current_image_sampling_step = 0 def get_job_timestamp(self): - return datetime.datetime.now().strftime("%Y%m%d%H%M%S") + return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp? state = State() diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py new file mode 100644 index 00000000..7e134a08 --- /dev/null +++ b/modules/textual_inversion/dataset.py @@ -0,0 +1,76 @@ +import os +import numpy as np +import PIL +import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + +import random +import tqdm + + +class PersonalizedBase(Dataset): + def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None): + + self.placeholder_token = placeholder_token + + self.size = size + self.width = width + self.height = height + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + self.dataset = [] + + with open(template_file, "r") as file: + lines = [x.strip() for x in file.readlines()] + + self.lines = lines + + assert data_root, 'dataset directory not specified' + + self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] + print("Preparing dataset...") + for path in tqdm.tqdm(self.image_paths): + image = Image.open(path) + image = image.convert('RGB') + image = image.resize((self.width, self.height), PIL.Image.BICUBIC) + + filename = os.path.basename(path) + filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-') + filename_tokens = [token for token in filename_tokens if token.isalpha()] + + npimage = np.array(image).astype(np.uint8) + npimage = (npimage / 127.5 - 1.0).astype(np.float32) + + torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32) + torchdata = torch.moveaxis(torchdata, 2, 0) + + init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() + + self.dataset.append((init_latent, filename_tokens)) + + self.length = len(self.dataset) * repeats + + self.initial_indexes = np.arange(self.length) % len(self.dataset) + self.indexes = None + self.shuffle() + + def shuffle(self): + self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] + + def __len__(self): + return self.length + + def __getitem__(self, i): + if i % len(self.dataset) == 0: + self.shuffle() + + index = self.indexes[i % len(self.indexes)] + x, filename_tokens = self.dataset[index] + + text = random.choice(self.lines) + text = text.replace("[name]", self.placeholder_token) + text = text.replace("[filewords]", ' '.join(filename_tokens)) + + return x, text diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py new file mode 100644 index 00000000..c0baaace --- /dev/null +++ b/modules/textual_inversion/textual_inversion.py @@ -0,0 +1,258 @@ +import os +import sys +import traceback + +import torch +import tqdm +import html +import datetime + +from modules import shared, devices, sd_hijack, processing +import modules.textual_inversion.dataset + + +class Embedding: + def __init__(self, vec, name, step=None): + self.vec = vec + self.name = name + self.step = step + self.cached_checksum = None + + def save(self, filename): + embedding_data = { + "string_to_token": {"*": 265}, + "string_to_param": {"*": self.vec}, + "name": self.name, + "step": self.step, + } + + torch.save(embedding_data, filename) + + def checksum(self): + if self.cached_checksum is not None: + return self.cached_checksum + + def const_hash(a): + r = 0 + for v in a: + r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF + return r + + self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}' + return self.cached_checksum + +class EmbeddingDatabase: + def __init__(self, embeddings_dir): + self.ids_lookup = {} + self.word_embeddings = {} + self.dir_mtime = None + self.embeddings_dir = embeddings_dir + + def register_embedding(self, embedding, model): + + self.word_embeddings[embedding.name] = embedding + + ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0] + + first_id = ids[0] + if first_id not in self.ids_lookup: + self.ids_lookup[first_id] = [] + self.ids_lookup[first_id].append((ids, embedding)) + + return embedding + + def load_textual_inversion_embeddings(self): + mt = os.path.getmtime(self.embeddings_dir) + if self.dir_mtime is not None and mt <= self.dir_mtime: + return + + self.dir_mtime = mt + self.ids_lookup.clear() + self.word_embeddings.clear() + + def process_file(path, filename): + name = os.path.splitext(filename)[0] + + data = torch.load(path, map_location="cpu") + + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + if hasattr(param_dict, '_parameters'): + param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + # diffuser concepts + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + else: + raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + + vec = emb.detach().to(devices.device, dtype=torch.float32) + embedding = Embedding(vec, name) + embedding.step = data.get('step', None) + self.register_embedding(embedding, shared.sd_model) + + for fn in os.listdir(self.embeddings_dir): + try: + fullfn = os.path.join(self.embeddings_dir, fn) + + if os.stat(fullfn).st_size == 0: + continue + + process_file(fullfn, fn) + except Exception: + print(f"Error loading emedding {fn}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + continue + + print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") + + def find_embedding_at_position(self, tokens, offset): + token = tokens[offset] + possible_matches = self.ids_lookup.get(token, None) + + if possible_matches is None: + return None + + for ids, embedding in possible_matches: + if tokens[offset:offset + len(ids)] == ids: + return embedding + + return None + + + +def create_embedding(name, num_vectors_per_token): + init_text = '*' + + cond_model = shared.sd_model.cond_stage_model + embedding_layer = cond_model.wrapped.transformer.text_model.embeddings + + ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"] + embedded = embedding_layer(ids.to(devices.device)).squeeze(0) + vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) + + for i in range(num_vectors_per_token): + vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] + + fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt") + assert not os.path.exists(fn), f"file {fn} already exists" + + embedding = Embedding(vec, name) + embedding.step = 0 + embedding.save(fn) + + return fn + + +def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file): + assert embedding_name, 'embedding not selected' + + shared.state.textinfo = "Initializing textual inversion training..." + shared.state.job_count = steps + + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') + + log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name) + + if save_embedding_every > 0: + embedding_dir = os.path.join(log_directory, "embeddings") + os.makedirs(embedding_dir, exist_ok=True) + else: + embedding_dir = None + + if create_image_every > 0: + images_dir = os.path.join(log_directory, "images") + os.makedirs(images_dir, exist_ok=True) + else: + images_dir = None + + cond_model = shared.sd_model.cond_stage_model + + shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." + with torch.autocast("cuda"): + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) + + hijack = sd_hijack.model_hijack + + embedding = hijack.embedding_db.word_embeddings[embedding_name] + embedding.vec.requires_grad = True + + optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate) + + losses = torch.zeros((32,)) + + last_saved_file = "" + last_saved_image = "" + + ititial_step = embedding.step or 0 + if ititial_step > steps: + return embedding, filename + + pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + for i, (x, text) in pbar: + embedding.step = i + ititial_step + + if embedding.step > steps: + break + + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + c = cond_model([text]) + loss = shared.sd_model(x.unsqueeze(0), c)[0] + + losses[embedding.step % losses.shape[0]] = loss.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + pbar.set_description(f"loss: {losses.mean():.7f}") + + if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0: + last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') + embedding.save(last_saved_file) + + if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: + last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + prompt=text, + steps=20, + do_not_save_grid=True, + do_not_save_samples=True, + ) + + processed = processing.process_images(p) + image = processed.images[0] + + shared.state.current_image = image + image.save(last_saved_image) + + last_saved_image += f", prompt: {text}" + + shared.state.job_no = embedding.step + + shared.state.textinfo = f""" +

+Loss: {losses.mean():.7f}
+Step: {embedding.step}
+Last prompt: {html.escape(text)}
+Last saved embedding: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+

+""" + + embedding.cached_checksum = None + embedding.save(filename) + + return embedding, filename + diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py new file mode 100644 index 00000000..ce3677a9 --- /dev/null +++ b/modules/textual_inversion/ui.py @@ -0,0 +1,32 @@ +import html + +import gradio as gr + +import modules.textual_inversion.textual_inversion as ti +from modules import sd_hijack, shared + + +def create_embedding(name, nvpt): + filename = ti.create_embedding(name, nvpt) + + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() + + return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", "" + + +def train_embedding(*args): + + try: + sd_hijack.undo_optimizations() + + embedding, filename = ti.train_embedding(*args) + + res = f""" +Training {'interrupted' if shared.state.interrupted else 'finished'} after {embedding.step} steps. +Embedding saved to {html.escape(filename)} +""" + return res, "" + except Exception: + raise + finally: + sd_hijack.apply_optimizations() diff --git a/modules/ui.py b/modules/ui.py index 15572bb0..57aef6ff 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -21,6 +21,7 @@ import gradio as gr import gradio.utils import gradio.routes +from modules import sd_hijack from modules.paths import script_path from modules.shared import opts, cmd_opts import modules.shared as shared @@ -32,6 +33,7 @@ import modules.gfpgan_model import modules.codeformer_model import modules.styles import modules.generation_parameters_copypaste +import modules.textual_inversion.ui # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI mimetypes.init() @@ -142,8 +144,8 @@ def save_files(js_data, images, index): return '', '', plaintext_to_html(f"Saved: {filenames[0]}") -def wrap_gradio_call(func): - def f(*args, **kwargs): +def wrap_gradio_call(func, extra_outputs=None): + def f(*args, extra_outputs_array=extra_outputs, **kwargs): run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled if run_memmon: shared.mem_mon.monitor() @@ -159,7 +161,10 @@ def wrap_gradio_call(func): shared.state.job = "" shared.state.job_count = 0 - res = [None, '', f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] + if extra_outputs_array is None: + extra_outputs_array = [None, ''] + + res = extra_outputs_array + [f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] elapsed = time.perf_counter() - t @@ -179,6 +184,7 @@ def wrap_gradio_call(func): res[-1] += f"

Time taken: {elapsed:.2f}s

{vram_html}
" shared.state.interrupted = False + shared.state.job_count = 0 return tuple(res) @@ -187,7 +193,7 @@ def wrap_gradio_call(func): def check_progress_call(id_part): if shared.state.job_count == 0: - return "", gr_show(False), gr_show(False) + return "", gr_show(False), gr_show(False), gr_show(False) progress = 0 @@ -219,13 +225,19 @@ def check_progress_call(id_part): else: preview_visibility = gr_show(True) - return f"

{progressbar}

", preview_visibility, image + if shared.state.textinfo is not None: + textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) + else: + textinfo_result = gr_show(False) + + return f"

{progressbar}

", preview_visibility, image, textinfo_result def check_progress_call_initial(id_part): shared.state.job_count = -1 shared.state.current_latent = None shared.state.current_image = None + shared.state.textinfo = None return check_progress_call(id_part) @@ -399,13 +411,16 @@ def create_toprow(is_img2img): return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste -def setup_progressbar(progressbar, preview, id_part): +def setup_progressbar(progressbar, preview, id_part, textinfo=None): + if textinfo is None: + textinfo = gr.HTML(visible=False) + check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) check_progress.click( fn=lambda: check_progress_call(id_part), show_progress=False, inputs=[], - outputs=[progressbar, preview, preview], + outputs=[progressbar, preview, preview, textinfo], ) check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) @@ -413,11 +428,14 @@ def setup_progressbar(progressbar, preview, id_part): fn=lambda: check_progress_call_initial(id_part), show_progress=False, inputs=[], - outputs=[progressbar, preview, preview], + outputs=[progressbar, preview, preview, textinfo], ) -def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): +def create_ui(wrap_gradio_gpu_call): + import modules.img2img + import modules.txt2img + with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) @@ -483,7 +501,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) txt2img_args = dict( - fn=txt2img, + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), _js="submit", inputs=[ txt2img_prompt, @@ -675,7 +693,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ) img2img_args = dict( - fn=img2img, + fn=wrap_gradio_gpu_call(modules.img2img.img2img), _js="submit_img2img", inputs=[ dummy_component, @@ -828,7 +846,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): open_extras_folder = gr.Button('Open output directory', elem_id=button_id) submit.click( - fn=run_extras, + fn=wrap_gradio_gpu_call(modules.extras.run_extras), _js="get_extras_tab_index", inputs=[ dummy_component, @@ -878,7 +896,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): pnginfo_send_to_img2img = gr.Button('Send to img2img') image.change( - fn=wrap_gradio_call(run_pnginfo), + fn=wrap_gradio_call(modules.extras.run_pnginfo), inputs=[image], outputs=[html, generation_info, html2], ) @@ -887,7 +905,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") - + with gr.Row(): primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name") secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name") @@ -896,10 +914,96 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method") save_as_half = gr.Checkbox(value=False, label="Safe as float16") modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') - + with gr.Column(variant='panel'): submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() + + with gr.Blocks() as textual_inversion_interface: + with gr.Row().style(equal_height=False): + with gr.Column(): + with gr.Group(): + gr.HTML(value="

Create a new embedding

") + + new_embedding_name = gr.Textbox(label="Name") + nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_embedding = gr.Button(value="Create", variant='primary') + + with gr.Group(): + gr.HTML(value="

Train an embedding; must specify a directory with a set of 512x512 images

") + train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + learn_rate = gr.Number(label='Learning rate', value=5.0e-03) + dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") + log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") + template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) + steps = gr.Number(label='Max steps', value=100000, precision=0) + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0) + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0) + + with gr.Row(): + with gr.Column(scale=2): + gr.HTML(value="") + + with gr.Column(): + with gr.Row(): + interrupt_training = gr.Button(value="Interrupt") + train_embedding = gr.Button(value="Train", variant='primary') + + with gr.Column(): + progressbar = gr.HTML(elem_id="ti_progressbar") + ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) + + ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) + ti_preview = gr.Image(elem_id='ti_preview', visible=False) + ti_progress = gr.HTML(elem_id="ti_progress", value="") + ti_outcome = gr.HTML(elem_id="ti_error", value="") + setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) + + create_embedding.click( + fn=modules.textual_inversion.ui.create_embedding, + inputs=[ + new_embedding_name, + nvpt, + ], + outputs=[ + train_embedding_name, + ti_output, + ti_outcome, + ] + ) + + train_embedding.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_embedding_name, + learn_rate, + dataset_directory, + log_directory, + steps, + create_image_every, + save_embedding_every, + template_file, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + interrupt_training.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + def create_setting_component(key): def fun(): return opts.data[key] if key in opts.data else opts.data_labels[key].default @@ -1011,6 +1115,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), + (textual_inversion_interface, "Textual inversion", "ti"), (settings_interface, "Settings", "settings"), ] @@ -1044,11 +1149,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): def modelmerger(*args): try: - results = run_modelmerger(*args) + results = modules.extras.run_modelmerger(*args) except Exception as e: print("Error loading/saving model file:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - modules.sd_models.list_models() #To remove the potentially missing models from the list + modules.sd_models.list_models() # to remove the potentially missing models from the list return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)] return results diff --git a/style.css b/style.css index 79d6bb0d..39586bf1 100644 --- a/style.css +++ b/style.css @@ -157,7 +157,7 @@ button{ max-width: 10em; } -#txt2img_preview, #img2img_preview{ +#txt2img_preview, #img2img_preview, #ti_preview{ position: absolute; width: 320px; left: 0; @@ -172,18 +172,18 @@ button{ } @media screen and (min-width: 768px) { - #txt2img_preview, #img2img_preview { + #txt2img_preview, #img2img_preview, #ti_preview { position: absolute; } } @media screen and (max-width: 767px) { - #txt2img_preview, #img2img_preview { + #txt2img_preview, #img2img_preview, #ti_preview { position: relative; } } -#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0{ +#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0, #ti_preview div.left-0.top-0{ display: none; } @@ -247,7 +247,7 @@ input[type="range"]{ #txt2img_negative_prompt, #img2img_negative_prompt{ } -#txt2img_progressbar, #img2img_progressbar{ +#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{ position: absolute; z-index: 1000; right: 0; diff --git a/textual_inversion_templates/style.txt b/textual_inversion_templates/style.txt new file mode 100644 index 00000000..15af2d6b --- /dev/null +++ b/textual_inversion_templates/style.txt @@ -0,0 +1,19 @@ +a painting, art by [name] +a rendering, art by [name] +a cropped painting, art by [name] +the painting, art by [name] +a clean painting, art by [name] +a dirty painting, art by [name] +a dark painting, art by [name] +a picture, art by [name] +a cool painting, art by [name] +a close-up painting, art by [name] +a bright painting, art by [name] +a cropped painting, art by [name] +a good painting, art by [name] +a close-up painting, art by [name] +a rendition, art by [name] +a nice painting, art by [name] +a small painting, art by [name] +a weird painting, art by [name] +a large painting, art by [name] diff --git a/textual_inversion_templates/style_filewords.txt b/textual_inversion_templates/style_filewords.txt new file mode 100644 index 00000000..b3a8159a --- /dev/null +++ b/textual_inversion_templates/style_filewords.txt @@ -0,0 +1,19 @@ +a painting of [filewords], art by [name] +a rendering of [filewords], art by [name] +a cropped painting of [filewords], art by [name] +the painting of [filewords], art by [name] +a clean painting of [filewords], art by [name] +a dirty painting of [filewords], art by [name] +a dark painting of [filewords], art by [name] +a picture of [filewords], art by [name] +a cool painting of [filewords], art by [name] +a close-up painting of [filewords], art by [name] +a bright painting of [filewords], art by [name] +a cropped painting of [filewords], art by [name] +a good painting of [filewords], art by [name] +a close-up painting of [filewords], art by [name] +a rendition of [filewords], art by [name] +a nice painting of [filewords], art by [name] +a small painting of [filewords], art by [name] +a weird painting of [filewords], art by [name] +a large painting of [filewords], art by [name] diff --git a/textual_inversion_templates/subject.txt b/textual_inversion_templates/subject.txt new file mode 100644 index 00000000..79f36aa0 --- /dev/null +++ b/textual_inversion_templates/subject.txt @@ -0,0 +1,27 @@ +a photo of a [name] +a rendering of a [name] +a cropped photo of the [name] +the photo of a [name] +a photo of a clean [name] +a photo of a dirty [name] +a dark photo of the [name] +a photo of my [name] +a photo of the cool [name] +a close-up photo of a [name] +a bright photo of the [name] +a cropped photo of a [name] +a photo of the [name] +a good photo of the [name] +a photo of one [name] +a close-up photo of the [name] +a rendition of the [name] +a photo of the clean [name] +a rendition of a [name] +a photo of a nice [name] +a good photo of a [name] +a photo of the nice [name] +a photo of the small [name] +a photo of the weird [name] +a photo of the large [name] +a photo of a cool [name] +a photo of a small [name] diff --git a/textual_inversion_templates/subject_filewords.txt b/textual_inversion_templates/subject_filewords.txt new file mode 100644 index 00000000..008652a6 --- /dev/null +++ b/textual_inversion_templates/subject_filewords.txt @@ -0,0 +1,27 @@ +a photo of a [name], [filewords] +a rendering of a [name], [filewords] +a cropped photo of the [name], [filewords] +the photo of a [name], [filewords] +a photo of a clean [name], [filewords] +a photo of a dirty [name], [filewords] +a dark photo of the [name], [filewords] +a photo of my [name], [filewords] +a photo of the cool [name], [filewords] +a close-up photo of a [name], [filewords] +a bright photo of the [name], [filewords] +a cropped photo of a [name], [filewords] +a photo of the [name], [filewords] +a good photo of the [name], [filewords] +a photo of one [name], [filewords] +a close-up photo of the [name], [filewords] +a rendition of the [name], [filewords] +a photo of the clean [name], [filewords] +a rendition of a [name], [filewords] +a photo of a nice [name], [filewords] +a good photo of a [name], [filewords] +a photo of the nice [name], [filewords] +a photo of the small [name], [filewords] +a photo of the weird [name], [filewords] +a photo of the large [name], [filewords] +a photo of a cool [name], [filewords] +a photo of a small [name], [filewords] diff --git a/webui.py b/webui.py index b8cccd54..19fdcdd4 100644 --- a/webui.py +++ b/webui.py @@ -12,7 +12,6 @@ import modules.bsrgan_model as bsrgan import modules.extras import modules.face_restoration import modules.gfpgan_model as gfpgan -import modules.img2img import modules.ldsr_model as ldsr import modules.lowvram import modules.realesrgan_model as realesrgan @@ -21,7 +20,6 @@ import modules.sd_hijack import modules.sd_models import modules.shared as shared import modules.swinir_model as swinir -import modules.txt2img import modules.ui from modules import modelloader from modules.paths import script_path @@ -46,7 +44,7 @@ def wrap_queued_call(func): return f -def wrap_gradio_gpu_call(func): +def wrap_gradio_gpu_call(func, extra_outputs=None): def f(*args, **kwargs): devices.torch_gc() @@ -58,6 +56,7 @@ def wrap_gradio_gpu_call(func): shared.state.current_image = None shared.state.current_image_sampling_step = 0 shared.state.interrupted = False + shared.state.textinfo = None with queue_lock: res = func(*args, **kwargs) @@ -69,7 +68,7 @@ def wrap_gradio_gpu_call(func): return res - return modules.ui.wrap_gradio_call(f) + return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) modules.scripts.load_scripts(os.path.join(script_path, "scripts")) @@ -86,13 +85,7 @@ def webui(): signal.signal(signal.SIGINT, sigint_handler) - demo = modules.ui.create_ui( - txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img), - img2img=wrap_gradio_gpu_call(modules.img2img.img2img), - run_extras=wrap_gradio_gpu_call(modules.extras.run_extras), - run_pnginfo=modules.extras.run_pnginfo, - run_modelmerger=modules.extras.run_modelmerger - ) + demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) demo.launch( share=cmd_opts.share, -- cgit v1.2.3