diff options
Diffstat (limited to 'extensions-builtin')
-rw-r--r-- | extensions-builtin/LDSR/ldsr_model_arch.py | 13 | ||||
-rw-r--r-- | extensions-builtin/LDSR/scripts/ldsr_model.py | 3 | ||||
-rw-r--r-- | extensions-builtin/LDSR/sd_hijack_autoencoder.py | 28 | ||||
-rw-r--r-- | extensions-builtin/LDSR/sd_hijack_ddpm_v1.py | 66 | ||||
-rw-r--r-- | extensions-builtin/Lora/lora.py | 15 | ||||
-rw-r--r-- | extensions-builtin/Lora/scripts/lora_script.py | 2 | ||||
-rw-r--r-- | extensions-builtin/ScuNET/scripts/scunet_model.py | 16 | ||||
-rw-r--r-- | extensions-builtin/ScuNET/scunet_model_arch.py | 11 | ||||
-rw-r--r-- | extensions-builtin/SwinIR/scripts/swinir_model.py | 7 | ||||
-rw-r--r-- | extensions-builtin/SwinIR/swinir_model_arch.py | 6 | ||||
-rw-r--r-- | extensions-builtin/SwinIR/swinir_model_arch_v2.py | 58 | ||||
-rw-r--r-- | extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js | 52 |
12 files changed, 144 insertions, 133 deletions
diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index bc11cc6e..7f450086 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -88,7 +88,7 @@ class LDSR: x_t = None logs = None - for n in range(n_runs): + for _ in range(n_runs): if custom_shape is not None: x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device) x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0]) @@ -110,7 +110,6 @@ class LDSR: diffusion_steps = int(steps) eta = 1.0 - down_sample_method = 'Lanczos' gc.collect() if torch.cuda.is_available: @@ -131,11 +130,11 @@ class LDSR: im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) else: print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)") - + # pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge')) - + logs = self.run(model["model"], im_padded, diffusion_steps, eta) sample = logs["sample"] @@ -158,7 +157,7 @@ class LDSR: def get_cond(selected_path): - example = dict() + example = {} up_f = 4 c = selected_path.convert('RGB') c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0) @@ -196,7 +195,7 @@ def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_s @torch.no_grad() def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False): - log = dict() + log = {} z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, return_first_stage_outputs=True, @@ -244,7 +243,7 @@ def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) log["sample_noquant"] = x_sample_noquant log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) - except: + except Exception: pass log["sample"] = x_sample diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index da19cff1..fbbe9005 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -7,7 +7,8 @@ from basicsr.utils.download_util import load_file_from_url from modules.upscaler import Upscaler, UpscalerData from ldsr_model_arch import LDSR from modules import shared, script_callbacks -import sd_hijack_autoencoder, sd_hijack_ddpm_v1 +import sd_hijack_autoencoder # noqa: F401 +import sd_hijack_ddpm_v1 # noqa: F401 class UpscalerLDSR(Upscaler): diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py index 8e03c7f8..81c5101b 100644 --- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py +++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py @@ -1,16 +1,21 @@ # The content of this file comes from the ldm/models/autoencoder.py file of the compvis/stable-diffusion repo # The VQModel & VQModelInterface were subsequently removed from ldm/models/autoencoder.py when we moved to the stability-ai/stablediffusion repo # As the LDSR upscaler relies on VQModel & VQModelInterface, the hijack aims to put them back into the ldm.models.autoencoder - +import numpy as np import torch import pytorch_lightning as pl import torch.nn.functional as F from contextlib import contextmanager + +from torch.optim.lr_scheduler import LambdaLR + +from ldm.modules.ema import LitEma from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer from ldm.modules.diffusionmodules.model import Encoder, Decoder from ldm.util import instantiate_from_config import ldm.models.autoencoder +from packaging import version class VQModel(pl.LightningModule): def __init__(self, @@ -19,7 +24,7 @@ class VQModel(pl.LightningModule): n_embed, embed_dim, ckpt_path=None, - ignore_keys=[], + ignore_keys=None, image_key="image", colorize_nlabels=None, monitor=None, @@ -57,7 +62,7 @@ class VQModel(pl.LightningModule): print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or []) self.scheduler_config = scheduler_config self.lr_g_factor = lr_g_factor @@ -76,11 +81,11 @@ class VQModel(pl.LightningModule): if context is not None: print(f"{context}: Restored training weights") - def init_from_ckpt(self, path, ignore_keys=list()): + def init_from_ckpt(self, path, ignore_keys=None): sd = torch.load(path, map_location="cpu")["state_dict"] keys = list(sd.keys()) for k in keys: - for ik in ignore_keys: + for ik in ignore_keys or []: if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] @@ -165,7 +170,7 @@ class VQModel(pl.LightningModule): def validation_step(self, batch, batch_idx): log_dict = self._validation_step(batch, batch_idx) with self.ema_scope(): - log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + self._validation_step(batch, batch_idx, suffix="_ema") return log_dict def _validation_step(self, batch, batch_idx, suffix=""): @@ -232,7 +237,7 @@ class VQModel(pl.LightningModule): return self.decoder.conv_out.weight def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): - log = dict() + log = {} x = self.get_input(batch, self.image_key) x = x.to(self.device) if only_inputs: @@ -249,7 +254,8 @@ class VQModel(pl.LightningModule): if plot_ema: with self.ema_scope(): xrec_ema, _ = self(x) - if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + if x.shape[1] > 3: + xrec_ema = self.to_rgb(xrec_ema) log["reconstructions_ema"] = xrec_ema return log @@ -264,7 +270,7 @@ class VQModel(pl.LightningModule): class VQModelInterface(VQModel): def __init__(self, embed_dim, *args, **kwargs): - super().__init__(embed_dim=embed_dim, *args, **kwargs) + super().__init__(*args, embed_dim=embed_dim, **kwargs) self.embed_dim = embed_dim def encode(self, x): @@ -282,5 +288,5 @@ class VQModelInterface(VQModel): dec = self.decoder(quant) return dec -setattr(ldm.models.autoencoder, "VQModel", VQModel) -setattr(ldm.models.autoencoder, "VQModelInterface", VQModelInterface) +ldm.models.autoencoder.VQModel = VQModel +ldm.models.autoencoder.VQModelInterface = VQModelInterface diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py index 5c0488e5..631a08ef 100644 --- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py +++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py @@ -48,7 +48,7 @@ class DDPMV1(pl.LightningModule): beta_schedule="linear", loss_type="l2", ckpt_path=None, - ignore_keys=[], + ignore_keys=None, load_only_unet=False, monitor="val/loss", use_ema=True, @@ -100,7 +100,7 @@ class DDPMV1(pl.LightningModule): if monitor is not None: self.monitor = monitor if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys or [], only_model=load_only_unet) self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) @@ -182,13 +182,13 @@ class DDPMV1(pl.LightningModule): if context is not None: print(f"{context}: Restored training weights") - def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + def init_from_ckpt(self, path, ignore_keys=None, only_model=False): sd = torch.load(path, map_location="cpu") if "state_dict" in list(sd.keys()): sd = sd["state_dict"] keys = list(sd.keys()) for k in keys: - for ik in ignore_keys: + for ik in ignore_keys or []: if k.startswith(ik): print("Deleting key {} from state_dict.".format(k)) del sd[k] @@ -375,7 +375,7 @@ class DDPMV1(pl.LightningModule): @torch.no_grad() def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): - log = dict() + log = {} x = self.get_input(batch, self.first_stage_key) N = min(x.shape[0], N) n_row = min(x.shape[0], n_row) @@ -383,7 +383,7 @@ class DDPMV1(pl.LightningModule): log["inputs"] = x # get diffusion row - diffusion_row = list() + diffusion_row = [] x_start = x[:n_row] for t in range(self.num_timesteps): @@ -444,13 +444,13 @@ class LatentDiffusionV1(DDPMV1): conditioning_key = None ckpt_path = kwargs.pop("ckpt_path", None) ignore_keys = kwargs.pop("ignore_keys", []) - super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + super().__init__(*args, conditioning_key=conditioning_key, **kwargs) self.concat_mode = concat_mode self.cond_stage_trainable = cond_stage_trainable self.cond_stage_key = cond_stage_key try: self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 - except: + except Exception: self.num_downs = 0 if not scale_by_std: self.scale_factor = scale_factor @@ -460,7 +460,7 @@ class LatentDiffusionV1(DDPMV1): self.instantiate_cond_stage(cond_stage_config) self.cond_stage_forward = cond_stage_forward self.clip_denoised = False - self.bbox_tokenizer = None + self.bbox_tokenizer = None self.restarted_from_ckpt = False if ckpt_path is not None: @@ -792,7 +792,7 @@ class LatentDiffusionV1(DDPMV1): z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) # 2. apply model loop over last dim - if isinstance(self.first_stage_model, VQModelInterface): + if isinstance(self.first_stage_model, VQModelInterface): output_list = [self.first_stage_model.decode(z[:, :, :, :, i], force_not_quantize=predict_cids or force_not_quantize) for i in range(z.shape[-1])] @@ -877,16 +877,6 @@ class LatentDiffusionV1(DDPMV1): c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) return self.p_losses(x, c, t, *args, **kwargs) - def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset - def rescale_bbox(bbox): - x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) - y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) - w = min(bbox[2] / crop_coordinates[2], 1 - x0) - h = min(bbox[3] / crop_coordinates[3], 1 - y0) - return x0, y0, w, h - - return [rescale_bbox(b) for b in bboxes] - def apply_model(self, x_noisy, t, cond, return_ids=False): if isinstance(cond, dict): @@ -900,7 +890,7 @@ class LatentDiffusionV1(DDPMV1): if hasattr(self, "split_input_params"): assert len(cond) == 1 # todo can only deal with one conditioning atm - assert not return_ids + assert not return_ids ks = self.split_input_params["ks"] # eg. (128, 128) stride = self.split_input_params["stride"] # eg. (64, 64) @@ -1126,7 +1116,7 @@ class LatentDiffusionV1(DDPMV1): if cond is not None: if isinstance(cond, dict): cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + [x[:batch_size] for x in cond[key]] for key in cond} else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] @@ -1157,8 +1147,10 @@ class LatentDiffusionV1(DDPMV1): if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(x0_partial) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) return img, intermediates @torch.no_grad() @@ -1205,8 +1197,10 @@ class LatentDiffusionV1(DDPMV1): if i % log_every_t == 0 or i == timesteps - 1: intermediates.append(img) - if callback: callback(i) - if img_callback: img_callback(img, i) + if callback: + callback(i) + if img_callback: + img_callback(img, i) if return_intermediates: return img, intermediates @@ -1221,7 +1215,7 @@ class LatentDiffusionV1(DDPMV1): if cond is not None: if isinstance(cond, dict): cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else - list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + [x[:batch_size] for x in cond[key]] for key in cond} else: cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] return self.p_sample_loop(cond, @@ -1253,7 +1247,7 @@ class LatentDiffusionV1(DDPMV1): use_ddim = ddim_steps is not None - log = dict() + log = {} z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, return_first_stage_outputs=True, force_c_encode=True, @@ -1280,7 +1274,7 @@ class LatentDiffusionV1(DDPMV1): if plot_diffusion_rows: # get diffusion row - diffusion_row = list() + diffusion_row = [] z_start = z[:n_row] for t in range(self.num_timesteps): if t % self.log_every_t == 0 or t == self.num_timesteps - 1: @@ -1322,7 +1316,7 @@ class LatentDiffusionV1(DDPMV1): if inpaint: # make a simple center square - b, h, w = z.shape[0], z.shape[2], z.shape[3] + h, w = z.shape[2], z.shape[3] mask = torch.ones(N, h, w).to(self.device) # zeros will be filled in mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. @@ -1424,10 +1418,10 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1): # TODO: move all layout-specific hacks to this class def __init__(self, cond_stage_key, *args, **kwargs): assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' - super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + super().__init__(*args, cond_stage_key=cond_stage_key, **kwargs) def log_images(self, batch, N=8, *args, **kwargs): - logs = super().log_images(batch=batch, N=N, *args, **kwargs) + logs = super().log_images(*args, batch=batch, N=N, **kwargs) key = 'train' if self.training else 'validation' dset = self.trainer.datamodule.datasets[key] @@ -1443,7 +1437,7 @@ class Layout2ImgDiffusionV1(LatentDiffusionV1): logs['bbox_image'] = cond_img return logs -setattr(ldm.models.diffusion.ddpm, "DDPMV1", DDPMV1) -setattr(ldm.models.diffusion.ddpm, "LatentDiffusionV1", LatentDiffusionV1) -setattr(ldm.models.diffusion.ddpm, "DiffusionWrapperV1", DiffusionWrapperV1) -setattr(ldm.models.diffusion.ddpm, "Layout2ImgDiffusionV1", Layout2ImgDiffusionV1) +ldm.models.diffusion.ddpm.DDPMV1 = DDPMV1 +ldm.models.diffusion.ddpm.LatentDiffusionV1 = LatentDiffusionV1 +ldm.models.diffusion.ddpm.DiffusionWrapperV1 = DiffusionWrapperV1 +ldm.models.diffusion.ddpm.Layout2ImgDiffusionV1 = Layout2ImgDiffusionV1 diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index b5d0c98f..1308c48b 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -1,4 +1,3 @@ -import glob
import os
import re
import torch
@@ -177,7 +176,7 @@ def load_lora(name, filename): else:
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
continue
- assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
+ raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
with torch.no_grad():
module.weight.copy_(weight)
@@ -189,7 +188,7 @@ def load_lora(name, filename): elif lora_key == "lora_down.weight":
lora_module.down = module
else:
- assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
+ raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
if len(keys_failed_to_match) > 0:
print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
@@ -207,7 +206,7 @@ def load_loras(names, multipliers=None): loaded_loras.clear()
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
- if any([x is None for x in loras_on_disk]):
+ if any(x is None for x in loras_on_disk):
list_available_loras()
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
@@ -314,7 +313,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu print(f'failed to calculate lora weights for layer {lora_layer_name}')
- setattr(self, "lora_current_names", wanted_names)
+ self.lora_current_names = wanted_names
def lora_forward(module, input, original_forward):
@@ -348,8 +347,8 @@ def lora_forward(module, input, original_forward): def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
- setattr(self, "lora_current_names", ())
- setattr(self, "lora_weights_backup", None)
+ self.lora_current_names = ()
+ self.lora_weights_backup = None
def lora_Linear_forward(self, input):
@@ -428,7 +427,7 @@ def infotext_pasted(infotext, params): added = []
- for k, v in params.items():
+ for k in params:
if not k.startswith("AddNet Model "):
continue
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 060bda05..728e0b86 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -53,7 +53,7 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted) shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
- "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
+ "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras),
"lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
}))
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index c7fd5739..cc2cbc6a 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -10,10 +10,9 @@ from tqdm import tqdm from basicsr.utils.download_util import load_file_from_url import modules.upscaler -from modules import devices, modelloader +from modules import devices, modelloader, script_callbacks from scunet_model_arch import SCUNet as net from modules.shared import opts -from modules import images class UpscalerScuNET(modules.upscaler.Upscaler): @@ -133,8 +132,19 @@ class UpscalerScuNET(modules.upscaler.Upscaler): model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) model.load_state_dict(torch.load(filename), strict=True) model.eval() - for k, v in model.named_parameters(): + for _, v in model.named_parameters(): v.requires_grad = False model = model.to(device) return model + + +def on_ui_settings(): + import gradio as gr + from modules import shared + + shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling")) + shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam")) + + +script_callbacks.on_ui_settings(on_ui_settings) diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py index 43ca8d36..b51a8806 100644 --- a/extensions-builtin/ScuNET/scunet_model_arch.py +++ b/extensions-builtin/ScuNET/scunet_model_arch.py @@ -61,7 +61,9 @@ class WMSA(nn.Module): Returns: output: tensor shape [b h w c] """ - if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) + if self.type != 'W': + x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) + x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) h_windows = x.size(1) w_windows = x.size(2) @@ -85,8 +87,9 @@ class WMSA(nn.Module): output = self.linear(output) output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size) - if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), - dims=(1, 2)) + if self.type != 'W': + output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2)) + return output def relative_embedding(self): @@ -262,4 +265,4 @@ class SCUNet(nn.Module): nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0)
\ No newline at end of file + nn.init.constant_(m.weight, 1.0) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index e8783bca..0ba50487 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,4 +1,3 @@ -import contextlib import os import numpy as np @@ -8,7 +7,7 @@ from basicsr.utils.download_util import load_file_from_url from tqdm import tqdm from modules import modelloader, devices, script_callbacks, shared -from modules.shared import cmd_opts, opts, state +from modules.shared import opts, state from swinir_model_arch import SwinIR as net from swinir_model_arch_v2 import Swin2SR as net2 from modules.upscaler import Upscaler, UpscalerData @@ -45,7 +44,7 @@ class UpscalerSwinIR(Upscaler): img = upscale(img, model) try: torch.cuda.empty_cache() - except: + except Exception: pass return img @@ -151,7 +150,7 @@ def inference(img, model, tile, tile_overlap, window_size, scale): for w_idx in w_idx_list: if state.interrupted or state.skipped: break - + in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] out_patch = model(in_patch) out_patch_mask = torch.ones_like(out_patch) diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py index 863f42db..93b93274 100644 --- a/extensions-builtin/SwinIR/swinir_model_arch.py +++ b/extensions-builtin/SwinIR/swinir_model_arch.py @@ -644,7 +644,7 @@ class SwinIR(nn.Module): """ def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, @@ -805,7 +805,7 @@ class SwinIR(nn.Module): def forward(self, x): H, W = x.shape[2:] x = self.check_image_size(x) - + self.mean = self.mean.type_as(x) x = (x - self.mean) * self.img_range @@ -844,7 +844,7 @@ class SwinIR(nn.Module): H, W = self.patches_resolution flops += H * W * 3 * self.embed_dim * 9 flops += self.patch_embed.flops() - for i, layer in enumerate(self.layers): + for layer in self.layers: flops += layer.flops() flops += H * W * 3 * self.embed_dim * self.embed_dim flops += self.upsample.flops() diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py index 0e28ae6e..dad22cca 100644 --- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py +++ b/extensions-builtin/SwinIR/swinir_model_arch_v2.py @@ -74,7 +74,7 @@ class WindowAttention(nn.Module): """
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
- pretrained_window_size=[0, 0]):
+ pretrained_window_size=(0, 0)):
super().__init__()
self.dim = dim
@@ -241,7 +241,7 @@ class SwinTransformerBlock(nn.Module): attn_mask = None
self.register_buffer("attn_mask", attn_mask)
-
+
def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA
H, W = x_size
@@ -263,7 +263,7 @@ class SwinTransformerBlock(nn.Module): attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
- return attn_mask
+ return attn_mask
def forward(self, x, x_size):
H, W = x_size
@@ -288,7 +288,7 @@ class SwinTransformerBlock(nn.Module): attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
else:
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
-
+
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
@@ -369,7 +369,7 @@ class PatchMerging(nn.Module): H, W = self.input_resolution
flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
flops += H * W * self.dim // 2
- return flops
+ return flops
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
@@ -447,7 +447,7 @@ class BasicLayer(nn.Module): nn.init.constant_(blk.norm1.weight, 0)
nn.init.constant_(blk.norm2.bias, 0)
nn.init.constant_(blk.norm2.weight, 0)
-
+
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
@@ -492,7 +492,7 @@ class PatchEmbed(nn.Module): flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
- return flops
+ return flops
class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB).
@@ -531,7 +531,7 @@ class RSTB(nn.Module): num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
+ qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path,
norm_layer=norm_layer,
@@ -622,7 +622,7 @@ class Upsample(nn.Sequential): else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
-
+
class Upsample_hf(nn.Sequential):
"""Upsample module.
@@ -642,7 +642,7 @@ class Upsample_hf(nn.Sequential): m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample_hf, self).__init__(*m)
+ super(Upsample_hf, self).__init__(*m)
class UpsampleOneStep(nn.Sequential):
@@ -667,8 +667,8 @@ class UpsampleOneStep(nn.Sequential): H, W = self.input_resolution
flops = H * W * self.num_feat * 3 * 9
return flops
-
-
+
+
class Swin2SR(nn.Module):
r""" Swin2SR
@@ -698,8 +698,8 @@ class Swin2SR(nn.Module): """
def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
- window_size=7, mlp_ratio=4., qkv_bias=True,
+ embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
+ window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
@@ -764,7 +764,7 @@ class Swin2SR(nn.Module): num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
+ qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
@@ -776,7 +776,7 @@ class Swin2SR(nn.Module): )
self.layers.append(layer)
-
+
if self.upsampler == 'pixelshuffle_hf':
self.layers_hf = nn.ModuleList()
for i_layer in range(self.num_layers):
@@ -787,7 +787,7 @@ class Swin2SR(nn.Module): num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
+ qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
@@ -799,7 +799,7 @@ class Swin2SR(nn.Module): )
self.layers_hf.append(layer)
-
+
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
@@ -829,10 +829,10 @@ class Swin2SR(nn.Module): self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.conv_after_aux = nn.Sequential(
nn.Conv2d(3, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
+ nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
+
elif self.upsampler == 'pixelshuffle_hf':
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
|