diff options
35 files changed, 521 insertions, 114 deletions
diff --git a/modules/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index 90e0a2f0..90e0a2f0 100644 --- a/modules/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py diff --git a/extensions-builtin/LDSR/preload.py b/extensions-builtin/LDSR/preload.py new file mode 100644 index 00000000..d746007c --- /dev/null +++ b/extensions-builtin/LDSR/preload.py @@ -0,0 +1,6 @@ +import os
+from modules import paths
+
+
+def preload(parser):
+ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
diff --git a/modules/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index 8c4db44a..1cef29a4 100644 --- a/modules/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -5,8 +5,9 @@ import traceback from basicsr.utils.download_util import load_file_from_url from modules.upscaler import Upscaler, UpscalerData -from modules.ldsr_model_arch import LDSR -from modules import shared +from ldsr_model_arch import LDSR +from modules import shared, script_callbacks +import sd_hijack_autoencoder class UpscalerLDSR(Upscaler): @@ -52,3 +53,12 @@ class UpscalerLDSR(Upscaler): return img ddim_steps = shared.opts.ldsr_steps return ldsr.super_resolution(img, ddim_steps, self.scale) + + +def on_ui_settings(): + import gradio as gr + + shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling"))) + + +script_callbacks.on_ui_settings(on_ui_settings) diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py new file mode 100644 index 00000000..8e03c7f8 --- /dev/null +++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py @@ -0,0 +1,286 @@ +# The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo +# The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo +# As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder + +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.util import instantiate_from_config + +import ldm.models.autoencoder + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_,_,ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + predicted_indices=ind) + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"val{suffix}/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + if version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor*self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr_g, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + +setattr(ldm.models.autoencoder, "VQModel", VQModel) +setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface) diff --git a/extensions-builtin/ScuNET/preload.py b/extensions-builtin/ScuNET/preload.py new file mode 100644 index 00000000..f12c5b90 --- /dev/null +++ b/extensions-builtin/ScuNET/preload.py @@ -0,0 +1,6 @@ +import os
+from modules import paths
+
+
+def preload(parser):
+ parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
diff --git a/modules/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 52360241..e0fbf3a3 100644 --- a/modules/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -9,7 +9,7 @@ from basicsr.utils.download_util import load_file_from_url import modules.upscaler from modules import devices, modelloader -from modules.scunet_model_arch import SCUNet as net +from scunet_model_arch import SCUNet as net class UpscalerScuNET(modules.upscaler.Upscaler): @@ -49,7 +49,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): if model is None: return img - device = devices.device_scunet + device = devices.get_device_for('scunet') img = np.array(img) img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 @@ -66,7 +66,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): return PIL.Image.fromarray(output, 'RGB') def load_model(self, path: str): - device = devices.device_scunet + device = devices.get_device_for('scunet') if "http" in path: filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, progress=True) diff --git a/modules/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py index 43ca8d36..43ca8d36 100644 --- a/modules/scunet_model_arch.py +++ b/extensions-builtin/ScuNET/scunet_model_arch.py diff --git a/extensions-builtin/SwinIR/preload.py b/extensions-builtin/SwinIR/preload.py new file mode 100644 index 00000000..567e44bc --- /dev/null +++ b/extensions-builtin/SwinIR/preload.py @@ -0,0 +1,6 @@ +import os
+from modules import paths
+
+
+def preload(parser):
+ parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
diff --git a/modules/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index facd262d..782769e2 100644 --- a/modules/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -7,15 +7,14 @@ from PIL import Image from basicsr.utils.download_util import load_file_from_url from tqdm import tqdm -from modules import modelloader, devices +from modules import modelloader, devices, script_callbacks, shared from modules.shared import cmd_opts, opts -from modules.swinir_model_arch import SwinIR as net -from modules.swinir_model_arch_v2 import Swin2SR as net2 +from swinir_model_arch import SwinIR as net +from swinir_model_arch_v2 import Swin2SR as net2 from modules.upscaler import Upscaler, UpscalerData -precision_scope = ( - torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext -) + +device_swinir = devices.get_device_for('swinir') class UpscalerSwinIR(Upscaler): @@ -42,7 +41,7 @@ class UpscalerSwinIR(Upscaler): model = self.load_model(model_file) if model is None: return img - model = model.to(devices.device_swinir) + model = model.to(device_swinir, dtype=devices.dtype) img = upscale(img, model) try: torch.cuda.empty_cache() @@ -94,8 +93,6 @@ class UpscalerSwinIR(Upscaler): model.load_state_dict(pretrained_model[params], strict=True) else: model.load_state_dict(pretrained_model, strict=True) - if not cmd_opts.no_half: - model = model.half() return model @@ -111,8 +108,8 @@ def upscale( img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(devices.device_swinir) - with torch.no_grad(), precision_scope("cuda"): + img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype) + with torch.no_grad(), devices.autocast(): _, _, h_old, w_old = img.size() h_pad = (h_old // window_size + 1) * window_size - h_old w_pad = (w_old // window_size + 1) * window_size - w_old @@ -139,8 +136,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale): stride = tile - tile_overlap h_idx_list = list(range(0, h - tile, stride)) + [h - tile] w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=devices.device_swinir).type_as(img) - W = torch.zeros_like(E, dtype=torch.half, device=devices.device_swinir) + E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img) + W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir) with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: for h_idx in h_idx_list: @@ -159,3 +156,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale): output = E.div_(W) return output + + +def on_ui_settings(): + import gradio as gr + + shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling"))) + shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling"))) + + +script_callbacks.on_ui_settings(on_ui_settings) diff --git a/modules/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py index 863f42db..863f42db 100644 --- a/modules/swinir_model_arch.py +++ b/extensions-builtin/SwinIR/swinir_model_arch.py diff --git a/modules/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py index 0e28ae6e..0e28ae6e 100644 --- a/modules/swinir_model_arch_v2.py +++ b/extensions-builtin/SwinIR/swinir_model_arch_v2.py diff --git a/javascript/hints.js b/javascript/hints.js index ac417ff6..57db35be 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -94,6 +94,8 @@ titles = { "Add difference": "Result = A + (B - C) * M", "Learning rate": "how fast should the training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG: 0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.", + + "Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc." } diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 43d1d1ce..d58737c4 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -92,14 +92,26 @@ function check_gallery(id_gallery){ if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) { // automatically re-open previously selected index (if exists) activeElement = gradioApp().activeElement; + let scrollX = window.scrollX; + let scrollY = window.scrollY; galleryButtons[prevSelectedIndex].click(); showGalleryImage(); + // When the gallery button is clicked, it gains focus and scrolls itself into view + // We need to scroll back to the previous position + setTimeout(function (){ + window.scrollTo(scrollX, scrollY); + }, 50); + if(activeElement){ // i fought this for about an hour; i don't know why the focus is lost or why this helps recover it - // if somenoe has a better solution please by all means - setTimeout(function() { activeElement.focus() }, 1); + // if someone has a better solution please by all means + setTimeout(function (){ + activeElement.focus({ + preventScroll: true // Refocus the element that was focused before the gallery was opened without scrolling to it + }) + }, 1); } } }) diff --git a/modules/api/api.py b/modules/api/api.py index 1de3f98f..54ee7cb0 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -152,7 +152,10 @@ class Api: ) if populate.sampler_name: populate.sampler_index = None # prevent a warning later on - p = StableDiffusionProcessingImg2Img(**vars(populate)) + + args = vars(populate) + args.pop('include_init_images', None) # this is meant to be done by "exclude": True in model, but it's for a reason that I cannot determine. + p = StableDiffusionProcessingImg2Img(**args) imgs = [] for img in init_images: @@ -170,7 +173,7 @@ class Api: b64images = list(map(encode_pil_to_base64, processed.images)) - if (not img2imgreq.include_init_images): + if not img2imgreq.include_init_images: img2imgreq.init_images = None img2imgreq.mask = None diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 31ec7e17..dfc83357 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -21,7 +21,7 @@ class DeepDanbooru: files = modelloader.load_models( model_path=os.path.join(paths.models_path, "torch_deepdanbooru"), model_url='https://github.com/AUTOMATIC1111/TorchDeepDanbooru/releases/download/v1/model-resnet_custom_v3.pt', - ext_filter=".pt", + ext_filter=[".pt"], download_name='model-resnet_custom_v3.pt', ) diff --git a/modules/devices.py b/modules/devices.py index f00079c6..f8cffae1 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -44,6 +44,15 @@ def get_optimal_device(): return cpu +def get_device_for(task): + from modules import shared + + if task in shared.cmd_opts.use_cpu: + return cpu + + return get_optimal_device() + + def torch_gc(): if torch.cuda.is_available(): with torch.cuda.device(get_cuda_device_string()): @@ -53,37 +62,35 @@ def torch_gc(): def enable_tf32(): if torch.cuda.is_available(): + + # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't + # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 + if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]): + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True + errors.run(enable_tf32, "Enabling TF32") cpu = torch.device("cpu") -device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None +device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None dtype = torch.float16 dtype_vae = torch.float16 def randn(seed, shape): - # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. - if device.type == 'mps': - generator = torch.Generator(device=cpu) - generator.manual_seed(seed) - noise = torch.randn(shape, generator=generator, device=cpu).to(device) - return noise - torch.manual_seed(seed) + if device.type == 'mps': + return torch.randn(shape, device=cpu).to(device) return torch.randn(shape, device=device) def randn_without_seed(shape): - # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. if device.type == 'mps': - generator = torch.Generator(device=cpu) - noise = torch.randn(shape, generator=generator, device=cpu).to(device) - return noise |