From 7c128bbdac0da1767c239174e91af6f327845372 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:56:17 +0800 Subject: Add fp8 for sd unet --- extensions-builtin/Lora/network.py | 2 +- extensions-builtin/Lora/network_full.py | 4 ++-- extensions-builtin/Lora/network_glora.py | 10 +++++----- extensions-builtin/Lora/network_hada.py | 12 ++++++------ extensions-builtin/Lora/network_ia3.py | 2 +- extensions-builtin/Lora/network_lokr.py | 18 +++++++++--------- extensions-builtin/Lora/network_lora.py | 6 +++--- extensions-builtin/Lora/network_norm.py | 4 ++-- extensions-builtin/Lora/networks.py | 6 +++--- 9 files changed, 32 insertions(+), 32 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 6021fd8d..a62e5eff 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -137,7 +137,7 @@ class NetworkModule: def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): if self.bias is not None: updown = updown.reshape(self.bias.shape) - updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) + updown += self.bias.to(orig_weight.device, dtype=updown.dtype) updown = updown.reshape(output_shape) if len(output_shape) == 4: diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py index bf6930e9..f221c95f 100644 --- a/extensions-builtin/Lora/network_full.py +++ b/extensions-builtin/Lora/network_full.py @@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule): def calc_updown(self, orig_weight): output_shape = self.weight.shape - updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype) + updown = self.weight.to(orig_weight.device) if self.ex_bias is not None: - ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype) + ex_bias = self.ex_bias.to(orig_weight.device) else: ex_bias = None diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py index 492d4870..efe5c681 100644 --- a/extensions-builtin/Lora/network_glora.py +++ b/extensions-builtin/Lora/network_glora.py @@ -22,12 +22,12 @@ class NetworkModuleGLora(network.NetworkModule): self.w2b = weights.w["b2.weight"] def calc_updown(self, orig_weight): - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) output_shape = [w1a.size(0), w1b.size(1)] - updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a)) + updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a)) return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py index 5fcb0695..d95a0fd1 100644 --- a/extensions-builtin/Lora/network_hada.py +++ b/extensions-builtin/Lora/network_hada.py @@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule): self.t2 = weights.w.get("hada_t2") def calc_updown(self, orig_weight): - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, 
dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) output_shape = [w1a.size(0), w1b.size(1)] if self.t1 is not None: output_shape = [w1a.size(1), w1b.size(1)] - t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype) + t1 = self.t1.to(orig_weight.device) updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b) output_shape += t1.shape[2:] else: @@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule): updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape) if self.t2 is not None: - t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) + t2 = self.t2.to(orig_weight.device) updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) else: updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape) diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py index 7edc4249..96faeaf3 100644 --- a/extensions-builtin/Lora/network_ia3.py +++ b/extensions-builtin/Lora/network_ia3.py @@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule): self.on_input = weights.w["on_input"].item() def calc_updown(self, orig_weight): - w = self.w.to(orig_weight.device, dtype=orig_weight.dtype) + w = self.w.to(orig_weight.device) output_shape = [w.size(0), orig_weight.size(1)] if self.on_input: diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py index 340acdab..fcdaeafd 100644 --- a/extensions-builtin/Lora/network_lokr.py +++ b/extensions-builtin/Lora/network_lokr.py @@ -37,22 +37,22 @@ class NetworkModuleLokr(network.NetworkModule): def calc_updown(self, orig_weight): if self.w1 is not None: - w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype) + w1 = self.w1.to(orig_weight.device) else: - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) w1 = w1a @ w1b if self.w2 is not None: - w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype) + w2 = self.w2.to(orig_weight.device) elif self.t2 is None: - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) w2 = w2a @ w2b else: - t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + t2 = self.t2.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)] diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py index 26c0a72c..4cc40295 100644 --- a/extensions-builtin/Lora/network_lora.py +++ b/extensions-builtin/Lora/network_lora.py @@ -61,13 +61,13 @@ class NetworkModuleLora(network.NetworkModule): return module def calc_updown(self, orig_weight): - up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) - down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + up = self.up_model.weight.to(orig_weight.device) + down = self.down_model.weight.to(orig_weight.device) output_shape = 
[up.size(0), down.size(1)] if self.mid_model is not None: # cp-decomposition - mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + mid = self.mid_model.weight.to(orig_weight.device) updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid) output_shape += mid.shape[2:] else: diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py index ce450158..d25afcbb 100644 --- a/extensions-builtin/Lora/network_norm.py +++ b/extensions-builtin/Lora/network_norm.py @@ -18,10 +18,10 @@ class NetworkModuleNorm(network.NetworkModule): def calc_updown(self, orig_weight): output_shape = self.w_norm.shape - updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype) + updown = self.w_norm.to(orig_weight.device) if self.b_norm is not None: - ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype) + ex_bias = self.b_norm.to(orig_weight.device) else: ex_bias = None diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 60d8dec4..8ea4ea60 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -381,12 +381,12 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn # inpainting model. zero pad updown to make channel[1] 4 to 9 updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) - self.weight += updown + self.weight.copy_((self.weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype)) if ex_bias is not None and hasattr(self, 'bias'): if self.bias is None: - self.bias = torch.nn.Parameter(ex_bias) + self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype) else: - self.bias += ex_bias + self.bias.copy_((self.bias.to(dtype=ex_bias.dtype) + ex_bias).to(dtype=self.bias.dtype)) except RuntimeError as e: logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 -- cgit v1.2.3 From 370a77f8e78e65a8a1339289d684cb43df142f70 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 21 Nov 2023 19:59:34 +0800 Subject: Option for using fp16 weight when apply lora --- extensions-builtin/Lora/networks.py | 16 ++++++++++++---- modules/initialize_util.py | 1 + modules/sd_models.py | 14 +++++++++++--- modules/shared_options.py | 1 + 4 files changed, 25 insertions(+), 7 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 0170dbfb..d22ed843 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -388,18 +388,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn if module is not None and hasattr(self, 'weight'): try: with torch.no_grad(): - updown, ex_bias = module.calc_updown(self.weight) + if getattr(self, 'fp16_weight', None) is None: + weight = self.weight + bias = self.bias + else: + weight = self.fp16_weight.clone().to(self.weight.device) + bias = getattr(self, 'fp16_bias', None) + if bias is not None: + bias = bias.clone().to(self.bias.device) + updown, ex_bias = module.calc_updown(weight) - if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: + if len(weight.shape) == 4 and weight.shape[1] == 9: # inpainting model. 
zero pad updown to make channel[1] 4 to 9 updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) - self.weight.copy_((self.weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype)) + self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype)) if ex_bias is not None and hasattr(self, 'bias'): if self.bias is None: self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype) else: - self.bias.copy_((self.bias.to(dtype=ex_bias.dtype) + ex_bias).to(dtype=self.bias.dtype)) + self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype)) except RuntimeError as e: logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 1b11ead6..7fb1d8d5 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -178,6 +178,7 @@ def configure_opts_onchange(): shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) + shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) startup_timer.record("opts onchange") diff --git a/modules/sd_models.py b/modules/sd_models.py index eb491434..0a7777f1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -413,14 +413,22 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") + for module in model.modules(): + if hasattr(module, 'fp16_weight'): + del module.fp16_weight + if hasattr(module, 'fp16_bias'): + del module.fp16_bias + if check_fp8(model): devices.fp8 = True first_stage = model.first_stage_model model.first_stage_model = None for module in model.modules(): - if isinstance(module, torch.nn.Conv2d): - module.to(torch.float8_e4m3fn) - elif isinstance(module, torch.nn.Linear): + if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): + if shared.opts.cache_fp16_weight: + module.fp16_weight = module.weight.clone().half() + if module.bias is not None: + module.fp16_bias = module.bias.clone().half() module.to(torch.float8_e4m3fn) model.first_stage_model = first_stage timer.record("apply fp8") diff --git a/modules/shared_options.py b/modules/shared_options.py index d27f35e9..eaa9f135 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -201,6 +201,7 @@ options_templates.update(options_section(('optimizations', "Optimizations"), { "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), + "cache_fp16_weight": OptionInfo(False, "Cache FP16 weight for LoRA").info("Cache fp16 weight when enabling FP8, will increase the quality of LoRA. 
Use more system ram."), })) options_templates.update(options_section(('compatibility', "Compatibility"), { -- cgit v1.2.3 From 22e23dbf29b0bbc807daa57318c31145f8dd0774 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 4 Dec 2023 15:56:03 +0300 Subject: add hypertile infotext --- .../hypertile/scripts/hypertile_script.py | 53 +++++++++++++++++----- 1 file changed, 42 insertions(+), 11 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/hypertile/scripts/hypertile_script.py b/extensions-builtin/hypertile/scripts/hypertile_script.py index d3ab6091..395d584b 100644 --- a/extensions-builtin/hypertile/scripts/hypertile_script.py +++ b/extensions-builtin/hypertile/scripts/hypertile_script.py @@ -17,11 +17,42 @@ class ScriptHypertile(scripts.Script): configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet) + self.add_infotext(p) + def before_hr(self, p, *args): + + enable = shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet + # exclusive hypertile seed for the second pass - if not shared.opts.hypertile_enable_unet: + if enable: hypertile.set_hypertile_seed(p.all_seeds[0]) - configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass) + + configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=enable) + + if enable and not shared.opts.hypertile_enable_unet: + p.extra_generation_params["Hypertile U-Net second pass"] = True + + self.add_infotext(p, add_unet_params=True) + + def add_infotext(self, p, add_unet_params=False): + def option(name): + value = getattr(shared.opts, name) + default_value = shared.opts.get_default(name) + return None if value == default_value else value + + if shared.opts.hypertile_enable_unet: + p.extra_generation_params["Hypertile U-Net"] = True + + if shared.opts.hypertile_enable_unet or add_unet_params: + p.extra_generation_params["Hypertile U-Net max depth"] = option('hypertile_max_depth_unet') + p.extra_generation_params["Hypertile U-Net max tile size"] = option('hypertile_max_tile_unet') + p.extra_generation_params["Hypertile U-Net swap size"] = option('hypertile_swap_size_unet') + + if shared.opts.hypertile_enable_vae: + p.extra_generation_params["Hypertile VAE"] = True + p.extra_generation_params["Hypertile VAE max depth"] = option('hypertile_max_depth_vae') + p.extra_generation_params["Hypertile VAE max tile size"] = option('hypertile_max_tile_vae') + p.extra_generation_params["Hypertile VAE swap size"] = option('hypertile_swap_size_vae') def configure_hypertile(width, height, enable_unet=True): @@ -57,16 +88,16 @@ def on_ui_settings(): benefit. 
"""), - "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net").info("noticeable change in details of the generated picture; if enabled, overrides the setting below"), - "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"), - "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}), - "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), - "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}), + "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net", infotext="Hypertile U-Net").info("enables hypertile for all modes, including hires fix second pass; noticeable change in details of the generated picture"), + "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass", infotext="Hypertile U-Net second pass").info("enables hypertile just for hires fix second pass - regardless of whether the above setting is enabled"), + "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile U-Net max depth").info("larger = more neural network layers affected; minor effect on performance"), + "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-Net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile U-Net max tile size").info("larger = worse performance"), + "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-Net swap size", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile U-Net swap size"), - "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE").info("minimal change in the generated picture"), - "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}), - "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), - "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}), + "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE", infotext="Hypertile VAE").info("minimal change in the generated picture"), + "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}, infotext="Hypertile VAE max depth"), + "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, infotext="Hypertile VAE max tile size"), + "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, infotext="Hypertile VAE swap size"), } for name, opt in options.items(): -- cgit v1.2.3 From 16bdcce92d5b482d50cdc32a8f308040d320b6c9 Mon Sep 17 00:00:00 2001 From: Rene Kroon Date: Fri, 8 Dec 2023 21:19:29 +0100 Subject: #13354: solve lora loading issue --- extensions-builtin/Lora/networks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 7f814706..629bf853 100644 --- a/extensions-builtin/Lora/networks.py +++ 
b/extensions-builtin/Lora/networks.py @@ -159,7 +159,8 @@ def load_network(name, network_on_disk): bundle_embeddings = {} for key_network, weight in sd.items(): - key_network_without_network_parts, network_part = key_network.split(".", 1) + key_network_without_network_parts, _, network_part = key_network.partition(".") + if key_network_without_network_parts == "bundle_emb": emb_name, vec_name = network_part.split(".", 1) emb_dict = bundle_embeddings.get(emb_name, {}) -- cgit v1.2.3 From 1a79a5049bdfef285235e83f37b201e39dd54f81 Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sat, 9 Dec 2023 22:35:31 +0600 Subject: Assign id for "extra_options". Replace numeric field with slider in Settings. --- .../extra-options-section/scripts/extra_options_section.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index a903df62..b9867fe6 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -23,11 +23,12 @@ class ExtraOptionsSection(scripts.Script): self.setting_names = [] self.infotext_fields = [] extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img + elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img") mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping} with gr.Blocks() as interface: - with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and extra_options else gr.Group(): + with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname): row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols) @@ -70,7 +71,7 @@ This page allows you to add some settings to the main interface of txt2img and i """), "extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), "extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(), - "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Number, {"precision": 0}).needs_reload_ui(), + "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 6}).needs_reload_ui(), "extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui() })) -- cgit v1.2.3 From 6b8143a84e112f029ee1868b6ab98b1d2c773ead Mon Sep 17 00:00:00 2001 From: kaalibro Date: Sun, 10 Dec 2023 15:35:06 +0600 Subject: Number of columns slider: max count set to 20, add description info --- .../extra-options-section/scripts/extra_options_section.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py 
b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index b9867fe6..ac2c3de4 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -71,7 +71,7 @@ This page allows you to add some settings to the main interface of txt2img and i """), "extra_options_txt2img": shared.OptionInfo([], "Settings for txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), "extra_options_img2img": shared.OptionInfo([], "Settings for img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(), - "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 6}).needs_reload_ui(), + "extra_options_cols": shared.OptionInfo(1, "Number of columns for added settings", gr.Slider, {"step": 1, "minimum": 1, "maximum": 20}).info("displayed amount will depend on the actual browser window width").needs_reload_ui(), "extra_options_accordion": shared.OptionInfo(False, "Place added settings into an accordion").needs_reload_ui() })) -- cgit v1.2.3 From 735c9e8059384d4f640e5582413c30871f83eac5 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:38:32 +0800 Subject: Fix network_oft --- extensions-builtin/Lora/network_oft.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 05c37811..44465f7a 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -53,12 +53,17 @@ class NetworkModuleOFT(network.NetworkModule): self.constraint = None self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) - def calc_updown_kb(self, orig_weight, multiplier): + def calc_updown(self, orig_weight): + I = torch.eye(self.block_size, device=self.oft_blocks.device) oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - oft_blocks = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + if self.is_kohya: + block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + oft_blocks = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - R = R * multiplier + torch.eye(self.block_size, device=orig_weight.device) # This errors out for MultiheadAttention, might need to be handled up-stream merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) @@ -70,15 +75,10 @@ class NetworkModuleOFT(network.NetworkModule): merged_weight = rearrange(merged_weight, 'k m ... 
-> (k m) ...') updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + print(torch.norm(updown)) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) - def calc_updown(self, orig_weight): - # if alpha is a very small number as in coft, calc_scale() will return a almost zero number so we ignore it - multiplier = self.multiplier() - return self.calc_updown_kb(orig_weight, multiplier) - - # override to remove the multiplier/scale factor; it's already multiplied in get_weight def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): if self.bias is not None: updown = updown.reshape(self.bias.shape) @@ -94,4 +94,5 @@ class NetworkModuleOFT(network.NetworkModule): if ex_bias is not None: ex_bias = ex_bias * self.multiplier() - return updown, ex_bias + # Ignore calc_scale, which is not used in OFT. + return updown * self.multiplier(), ex_bias -- cgit v1.2.3 From 265bc26c21264d63956e8f30f1ce31dec917fc76 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:43:24 +0800 Subject: Use self.scale instead of custom finalize --- extensions-builtin/Lora/network_oft.py | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 44465f7a..e3ae61a2 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -21,6 +21,8 @@ class NetworkModuleOFT(network.NetworkModule): self.lin_module = None self.org_module: list[torch.Module] = [self.sd_module] + self.scale = 1.0 + # kohya-ss if "oft_blocks" in weights.w.keys(): self.is_kohya = True @@ -78,21 +80,3 @@ class NetworkModuleOFT(network.NetworkModule): print(torch.norm(updown)) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) - - def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None): - if self.bias is not None: - updown = updown.reshape(self.bias.shape) - updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) - updown = updown.reshape(output_shape) - - if len(output_shape) == 4: - updown = updown.reshape(output_shape) - - if orig_weight.size().numel() == updown.size().numel(): - updown = updown.reshape(orig_weight.shape) - - if ex_bias is not None: - ex_bias = ex_bias * self.multiplier() - - # Ignore calc_scale, which is not used in OFT. - return updown * self.multiplier(), ex_bias -- cgit v1.2.3 From 8fc67f3851babd4575d3312b931d5e7c2b0c78c6 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:44:49 +0800 Subject: remove debug print --- extensions-builtin/Lora/network_oft.py | 1 - 1 file changed, 1 deletion(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index e3ae61a2..ff4eb59b 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -77,6 +77,5 @@ class NetworkModuleOFT(network.NetworkModule): merged_weight = rearrange(merged_weight, 'k m ... 
-> (k m) ...') updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight - print(torch.norm(updown)) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) -- cgit v1.2.3 From 3772a82a70769fe1aac884a75bf5a3313fb83328 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 14 Dec 2023 01:47:13 +0800 Subject: better naming and correct order for device. --- extensions-builtin/Lora/network_oft.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index ff4eb59b..fa647020 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -56,14 +56,15 @@ class NetworkModuleOFT(network.NetworkModule): self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) def calc_updown(self, orig_weight): - I = torch.eye(self.block_size, device=self.oft_blocks.device) oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + eye = torch.eye(self.block_size, device=self.oft_blocks.device) + if self.is_kohya: block_Q = oft_blocks - oft_blocks.transpose(1, 2) # ensure skew-symmetric orthogonal matrix norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - oft_blocks = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) + oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) -- cgit v1.2.3 From 93eae69895c34361a71dbed17348bcfd132fbc6a Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 16 Dec 2023 11:00:42 +0300 Subject: move soft inpainting to a built-in extension --- .../soft-inpainting/scripts/soft_inpainting.py | 747 +++++++++++++++++++++ scripts/soft_inpainting.py | 747 --------------------- 2 files changed, 747 insertions(+), 747 deletions(-) create mode 100644 extensions-builtin/soft-inpainting/scripts/soft_inpainting.py delete mode 100644 scripts/soft_inpainting.py (limited to 'extensions-builtin') diff --git a/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py b/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py new file mode 100644 index 00000000..d9024344 --- /dev/null +++ b/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py @@ -0,0 +1,747 @@ +import numpy as np +import gradio as gr +import math +from modules.ui_components import InputAccordion +import modules.scripts as scripts + + +class SoftInpaintingSettings: + def __init__(self, + mask_blend_power, + mask_blend_scale, + inpaint_detail_preservation, + composite_mask_influence, + composite_difference_threshold, + composite_difference_contrast): + self.mask_blend_power = mask_blend_power + self.mask_blend_scale = mask_blend_scale + self.inpaint_detail_preservation = inpaint_detail_preservation + self.composite_mask_influence = composite_mask_influence + self.composite_difference_threshold = composite_difference_threshold + self.composite_difference_contrast = composite_difference_contrast + + def add_generation_params(self, dest): + dest[enabled_gen_param_label] = True + dest[gen_param_labels.mask_blend_power] = self.mask_blend_power + dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale + dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation 
+ dest[gen_param_labels.composite_mask_influence] = self.composite_mask_influence + dest[gen_param_labels.composite_difference_threshold] = self.composite_difference_threshold + dest[gen_param_labels.composite_difference_contrast] = self.composite_difference_contrast + + +# ------------------- Methods ------------------- + +def processing_uses_inpainting(p): + # TODO: Figure out a better way to determine if inpainting is being used by p + if getattr(p, "image_mask", None) is not None: + return True + + if getattr(p, "mask", None) is not None: + return True + + if getattr(p, "nmask", None) is not None: + return True + + return False + + +def latent_blend(settings, a, b, t): + """ + Interpolates two latent image representations according to the parameter t, + where the interpolated vectors' magnitudes are also interpolated separately. + The "detail_preservation" factor biases the magnitude interpolation towards + the larger of the two magnitudes. + """ + import torch + + # NOTE: We use inplace operations wherever possible. + + # [4][w][h] to [1][4][w][h] + t2 = t.unsqueeze(0) + # [4][w][h] to [1][1][w][h] - the [4] seem redundant. + t3 = t[0].unsqueeze(0).unsqueeze(0) + + one_minus_t2 = 1 - t2 + one_minus_t3 = 1 - t3 + + # Linearly interpolate the image vectors. + a_scaled = a * one_minus_t2 + b_scaled = b * t2 + image_interp = a_scaled + image_interp.add_(b_scaled) + result_type = image_interp.dtype + del a_scaled, b_scaled, t2, one_minus_t2 + + # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) + # 64-bit operations are used here to allow large exponents. + current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001) + + # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). + a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_( + settings.inpaint_detail_preservation) * one_minus_t3 + b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_( + settings.inpaint_detail_preservation) * t3 + desired_magnitude = a_magnitude + desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation) + del a_magnitude, b_magnitude, t3, one_minus_t3 + + # Change the linearly interpolated image vectors' magnitudes to the value we want. + # This is the last 64-bit operation. + image_interp_scaling_factor = desired_magnitude + image_interp_scaling_factor.div_(current_magnitude) + image_interp_scaling_factor = image_interp_scaling_factor.to(result_type) + image_interp_scaled = image_interp + image_interp_scaled.mul_(image_interp_scaling_factor) + del current_magnitude + del desired_magnitude + del image_interp + del image_interp_scaling_factor + del result_type + + return image_interp_scaled + + +def get_modified_nmask(settings, nmask, sigma): + """ + Converts a negative mask representing the transparency of the original latent vectors being overlayed + to a mask that is scaled according to the denoising strength for this step. + + Where: + 0 = fully opaque, infinite density, fully masked + 1 = fully transparent, zero density, fully unmasked + + We bring this transparency to a power, as this allows one to simulate N number of blending operations + where N can be any positive real value. Using this one can control the balance of influence between + the denoiser and the original latents according to the sigma value. 
+ + NOTE: "mask" is not used + """ + import torch + return torch.pow(nmask, (sigma ** settings.mask_blend_power) * settings.mask_blend_scale) + + +def apply_adaptive_masks( + settings: SoftInpaintingSettings, + nmask, + latent_orig, + latent_processed, + overlay_images, + width, height, + paste_to): + import torch + import modules.processing as proc + import modules.images as images + from PIL import Image, ImageOps, ImageFilter + + # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. + latent_mask = nmask[0].float() + # convert the original mask into a form we use to scale distances for thresholding + mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2)) + mask_scalar = (0.5 * (1 - settings.composite_mask_influence) + + mask_scalar * settings.composite_mask_influence) + mask_scalar = mask_scalar / (1.00001 - mask_scalar) + mask_scalar = mask_scalar.cpu().numpy() + + latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) + + kernel, kernel_center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2) + + masks_for_overlay = [] + + for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): + converted_mask = distance_map.float().cpu().numpy() + converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.9, percentile_max=1, min_width=1) + converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.25, percentile_max=0.75, min_width=1) + + # The distance at which opacity of original decreases to 50% + half_weighted_distance = settings.composite_difference_threshold * mask_scalar + converted_mask = converted_mask / half_weighted_distance + + converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast) + converted_mask = smootherstep(converted_mask) + converted_mask = 1 - converted_mask + converted_mask = 255. * converted_mask + converted_mask = converted_mask.astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, width, height) + converted_mask = proc.create_binary_mask(converted_mask, round=False) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if paste_to is not None: + converted_mask = proc.uncrop(converted_mask, + (overlay_image.width, overlay_image.height), + paste_to) + + masks_for_overlay.append(converted_mask) + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + overlay_images[i] = image_masked.convert('RGBA') + + return masks_for_overlay + + +def apply_masks( + settings, + nmask, + overlay_images, + width, height, + paste_to): + import torch + import modules.processing as proc + import modules.images as images + from PIL import Image, ImageOps, ImageFilter + + converted_mask = nmask[0].float() + converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(settings.mask_blend_scale / 2) + converted_mask = 255. 
* converted_mask + converted_mask = converted_mask.cpu().numpy().astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, width, height) + converted_mask = proc.create_binary_mask(converted_mask, round=False) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if paste_to is not None: + converted_mask = proc.uncrop(converted_mask, + (width, height), + paste_to) + + masks_for_overlay = [] + + for i, overlay_image in enumerate(overlay_images): + masks_for_overlay[i] = converted_mask + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + overlay_images[i] = image_masked.convert('RGBA') + + return masks_for_overlay + + +def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0): + """ + Generalization convolution filter capable of applying + weighted mean, median, maximum, and minimum filters + parametrically using an arbitrary kernel. + + Args: + img (nparray): + The image, a 2-D array of floats, to which the filter is being applied. + kernel (nparray): + The kernel, a 2-D array of floats. + kernel_center (nparray): + The kernel center coordinate, a 1-D array with two elements. + percentile_min (float): + The lower bound of the histogram window used by the filter, + from 0 to 1. + percentile_max (float): + The upper bound of the histogram window used by the filter, + from 0 to 1. + min_width (float): + The minimum size of the histogram window bounds, in weight units. + Must be greater than 0. + + Returns: + (nparray): A filtered copy of the input image "img", a 2-D array of floats. + """ + + # Converts an index tuple into a vector. + def vec(x): + return np.array(x) + + kernel_min = -kernel_center + kernel_max = vec(kernel.shape) - kernel_center + + def weighted_histogram_filter_single(idx): + idx = vec(idx) + min_index = np.maximum(0, idx + kernel_min) + max_index = np.minimum(vec(img.shape), idx + kernel_max) + window_shape = max_index - min_index + + class WeightedElement: + """ + An element of the histogram, its weight + and bounds. + """ + + def __init__(self, value, weight): + self.value: float = value + self.weight: float = weight + self.window_min: float = 0.0 + self.window_max: float = 1.0 + + # Collect the values in the image as WeightedElements, + # weighted by their corresponding kernel values. + values = [] + for window_tup in np.ndindex(tuple(window_shape)): + window_index = vec(window_tup) + image_index = window_index + min_index + centered_kernel_index = image_index - idx + kernel_index = centered_kernel_index + kernel_center + element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)]) + values.append(element) + + def sort_key(x: WeightedElement): + return x.value + + values.sort(key=sort_key) + + # Calculate the height of the stack (sum) + # and each sample's range they occupy in the stack + sum = 0 + for i in range(len(values)): + values[i].window_min = sum + sum += values[i].weight + values[i].window_max = sum + + # Calculate what range of this stack ("window") + # we want to get the weighted average across. 
+ window_min = sum * percentile_min + window_max = sum * percentile_max + window_width = window_max - window_min + + # Ensure the window is within the stack and at least a certain size. + if window_width < min_width: + window_center = (window_min + window_max) / 2 + window_min = window_center - min_width / 2 + window_max = window_center + min_width / 2 + + if window_max > sum: + window_max = sum + window_min = sum - min_width + + if window_min < 0: + window_min = 0 + window_max = min_width + + value = 0 + value_weight = 0 + + # Get the weighted average of all the samples + # that overlap with the window, weighted + # by the size of their overlap. + for i in range(len(values)): + if window_min >= values[i].window_max: + continue + if window_max <= values[i].window_min: + break + + s = max(window_min, values[i].window_min) + e = min(window_max, values[i].window_max) + w = e - s + + value += values[i].value * w + value_weight += w + + return value / value_weight if value_weight != 0 else 0 + + img_out = img.copy() + + # Apply the kernel operation over each pixel. + for index in np.ndindex(img.shape): + img_out[index] = weighted_histogram_filter_single(index) + + return img_out + + +def smoothstep(x): + """ + The smoothstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * (3 - 2 * x) + + +def smootherstep(x): + """ + The smootherstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * x * (x * (6 * x - 15) + 10) + + +def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): + """ + Creates a Gaussian kernel with thresholded edges. + + Args: + stddev_radius (float): + Standard deviation of the gaussian kernel, in pixels. + max_radius (int): + The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2. + The kernel is thresholded so that any values one pixel beyond this radius + is weighted at 0. + + Returns: + (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2)) + """ + + # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean. + def gaussian(sqr_mag): + return math.exp(-sqr_mag / (stddev_radius * stddev_radius)) + + # Helper function for converting a tuple to an array. + def vec(x): + return np.array(x) + + """ + Since a gaussian is unbounded, we need to limit ourselves + to a finite range. + We taper the ends off at the end of that range so they equal zero + while preserving the maximum value of 1 at the mean. 
+ """ + zero_radius = max_radius + 1.0 + gauss_zero = gaussian(zero_radius * zero_radius) + gauss_kernel_scale = 1 / (1 - gauss_zero) + + def gaussian_kernel_func(coordinate): + x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0 + x = gaussian(x) + x -= gauss_zero + x *= gauss_kernel_scale + x = max(0.0, x) + return x + + size = max_radius * 2 + 1 + kernel_center = max_radius + kernel = np.zeros((size, size)) + + for index in np.ndindex(kernel.shape): + kernel[index] = gaussian_kernel_func(vec(index) - kernel_center) + + return kernel, kernel_center + + +# ------------------- Constants ------------------- + + +default = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2) + +enabled_ui_label = "Soft inpainting" +enabled_gen_param_label = "Soft inpainting enabled" +enabled_el_id = "soft_inpainting_enabled" + +ui_labels = SoftInpaintingSettings( + "Schedule bias", + "Preservation strength", + "Transition contrast boost", + "Mask influence", + "Difference threshold", + "Difference contrast") + +ui_info = SoftInpaintingSettings( + "Shifts when preservation of original content occurs during denoising.", + "How strongly partially masked content should be preserved.", + "Amplifies the contrast that may be lost in partially masked regions.", + "How strongly the original mask should bias the difference threshold.", + "How much an image region can change before the original pixels are not blended in anymore.", + "How sharp the transition should be between blended and not blended.") + +gen_param_labels = SoftInpaintingSettings( + "Soft inpainting schedule bias", + "Soft inpainting preservation strength", + "Soft inpainting transition contrast boost", + "Soft inpainting mask influence", + "Soft inpainting difference threshold", + "Soft inpainting difference contrast") + +el_ids = SoftInpaintingSettings( + "mask_blend_power", + "mask_blend_scale", + "inpaint_detail_preservation", + "composite_mask_influence", + "composite_difference_threshold", + "composite_difference_contrast") + + +# ------------------- Script ------------------- + + +class Script(scripts.Script): + def __init__(self): + self.section = "inpaint" + self.masks_for_overlay = None + self.overlay_images = None + + def title(self): + return "Soft Inpainting" + + def show(self, is_img2img): + return scripts.AlwaysVisible if is_img2img else False + + def ui(self, is_img2img): + if not is_img2img: + return + + with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled: + with gr.Group(): + gr.Markdown( + """ + Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity. + **High _Mask blur_** values are recommended! 
+ """) + + power = \ + gr.Slider(label=ui_labels.mask_blend_power, + info=ui_info.mask_blend_power, + minimum=0, + maximum=8, + step=0.1, + value=default.mask_blend_power, + elem_id=el_ids.mask_blend_power) + scale = \ + gr.Slider(label=ui_labels.mask_blend_scale, + info=ui_info.mask_blend_scale, + minimum=0, + maximum=8, + step=0.05, + value=default.mask_blend_scale, + elem_id=el_ids.mask_blend_scale) + detail = \ + gr.Slider(label=ui_labels.inpaint_detail_preservation, + info=ui_info.inpaint_detail_preservation, + minimum=1, + maximum=32, + step=0.5, + value=default.inpaint_detail_preservation, + elem_id=el_ids.inpaint_detail_preservation) + + gr.Markdown( + """ + ### Pixel Composite Settings + """) + + mask_inf = \ + gr.Slider(label=ui_labels.composite_mask_influence, + info=ui_info.composite_mask_influence, + minimum=0, + maximum=1, + step=0.05, + value=default.composite_mask_influence, + elem_id=el_ids.composite_mask_influence) + + dif_thresh = \ + gr.Slider(label=ui_labels.composite_difference_threshold, + info=ui_info.composite_difference_threshold, + minimum=0, + maximum=8, + step=0.25, + value=default.composite_difference_threshold, + elem_id=el_ids.composite_difference_threshold) + + dif_contr = \ + gr.Slider(label=ui_labels.composite_difference_contrast, + info=ui_info.composite_difference_contrast, + minimum=0, + maximum=8, + step=0.25, + value=default.composite_difference_contrast, + elem_id=el_ids.composite_difference_contrast) + + with gr.Accordion("Help", open=False): + gr.Markdown( + f""" + ### {ui_labels.mask_blend_power} + + The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas). + This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step. + This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation. + + - **Below 1**: Stronger preservation near the end (with low sigma) + - **1**: Balanced (proportional to sigma) + - **Above 1**: Stronger preservation in the beginning (with high sigma) + """) + gr.Markdown( + f""" + ### {ui_labels.mask_blend_scale} + + Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content. + This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength. + + - **Low values**: Favors generated content. + - **High values**: Favors original content. + """) + gr.Markdown( + f""" + ### {ui_labels.inpaint_detail_preservation} + + This parameter controls how the original latent vectors and denoised latent vectors are interpolated. + With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors. + This can prevent the loss of contrast that occurs with linear interpolation. + + - **Low values**: Softer blending, details may fade. + - **High values**: Stronger contrast, may over-saturate colors. + """) + + gr.Markdown( + """ + ## Pixel Composite Settings + + Masks are generated based on how much a part of the image changed after denoising. + These masks are used to blend the original and final images together. + If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_mask_influence} + + This parameter controls how much the mask should bias this sensitivity to difference. 
+ + - **0**: Ignore the mask, only consider differences in image content. + - **1**: Follow the mask closely despite image content changes. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_difference_threshold} + + This value represents the difference at which the original pixels will have less than 50% opacity. + + - **Low values**: Two images patches must be almost the same in order to retain original pixels. + - **High values**: Two images patches can be very different and still retain original pixels. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_difference_contrast} + + This value represents the contrast between the opacity of the original and inpainted content. + + - **Low values**: The blend will be more gradual and have longer transitions, but may cause ghosting. + - **High values**: Ghosting will be less common, but transitions may be very sudden. + """) + + self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label), + (power, gen_param_labels.mask_blend_power), + (scale, gen_param_labels.mask_blend_scale), + (detail, gen_param_labels.inpaint_detail_preservation), + (mask_inf, gen_param_labels.composite_mask_influence), + (dif_thresh, gen_param_labels.composite_difference_threshold), + (dif_contr, gen_param_labels.composite_difference_contrast)] + + self.paste_field_names = [] + for _, field_name in self.infotext_fields: + self.paste_field_names.append(field_name) + + return [soft_inpainting_enabled, + power, + scale, + detail, + mask_inf, + dif_thresh, + dif_contr] + + def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): + if not enabled: + return + + if not processing_uses_inpainting(p): + return + + # Shut off the rounding it normally does. + p.mask_round = False + + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) + + # p.extra_generation_params["Mask rounding"] = False + settings.add_generation_params(p.extra_generation_params) + + def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf, + dif_thresh, dif_contr): + if not enabled: + return + + if not processing_uses_inpainting(p): + return + + if mba.is_final_blend: + mba.blended_latent = mba.current_latent + return + + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) + + # todo: Why is sigma 2D? Both values are the same. + mba.blended_latent = latent_blend(settings, + mba.init_latent, + mba.current_latent, + get_modified_nmask(settings, mba.nmask, mba.sigma[0])) + + def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf, + dif_thresh, dif_contr): + if not enabled: + return + + if not processing_uses_inpainting(p): + return + + nmask = getattr(p, "nmask", None) + if nmask is None: + return + + from modules import images + from modules.shared import opts + + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) + + # since the original code puts holes in the existing overlay images, + # we have to rebuild them. 
+ self.overlay_images = [] + for img in p.init_images: + + image = images.flatten(img, opts.img2img_background_color) + + if p.paste_to is None and p.resize_mode != 3: + image = images.resize_image(p.resize_mode, image, p.width, p.height) + + self.overlay_images.append(image.convert('RGBA')) + + if len(p.init_images) == 1: + self.overlay_images = self.overlay_images * p.batch_size + + if getattr(ps.samples, 'already_decoded', False): + self.masks_for_overlay = apply_masks(settings=settings, + nmask=nmask, + overlay_images=self.overlay_images, + width=p.width, + height=p.height, + paste_to=p.paste_to) + else: + self.masks_for_overlay = apply_adaptive_masks(settings=settings, + nmask=nmask, + latent_orig=p.init_latent, + latent_processed=ps.samples, + overlay_images=self.overlay_images, + width=p.width, + height=p.height, + paste_to=p.paste_to) + + def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, + detail_preservation, mask_inf, dif_thresh, dif_contr): + if not enabled: + return + + if not processing_uses_inpainting(p): + return + + if self.masks_for_overlay is None: + return + + if self.overlay_images is None: + return + + ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index] + ppmo.overlay_image = self.overlay_images[ppmo.index] diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py deleted file mode 100644 index d9024344..00000000 --- a/scripts/soft_inpainting.py +++ /dev/null @@ -1,747 +0,0 @@ -import numpy as np -import gradio as gr -import math -from modules.ui_components import InputAccordion -import modules.scripts as scripts - - -class SoftInpaintingSettings: - def __init__(self, - mask_blend_power, - mask_blend_scale, - inpaint_detail_preservation, - composite_mask_influence, - composite_difference_threshold, - composite_difference_contrast): - self.mask_blend_power = mask_blend_power - self.mask_blend_scale = mask_blend_scale - self.inpaint_detail_preservation = inpaint_detail_preservation - self.composite_mask_influence = composite_mask_influence - self.composite_difference_threshold = composite_difference_threshold - self.composite_difference_contrast = composite_difference_contrast - - def add_generation_params(self, dest): - dest[enabled_gen_param_label] = True - dest[gen_param_labels.mask_blend_power] = self.mask_blend_power - dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale - dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation - dest[gen_param_labels.composite_mask_influence] = self.composite_mask_influence - dest[gen_param_labels.composite_difference_threshold] = self.composite_difference_threshold - dest[gen_param_labels.composite_difference_contrast] = self.composite_difference_contrast - - -# ------------------- Methods ------------------- - -def processing_uses_inpainting(p): - # TODO: Figure out a better way to determine if inpainting is being used by p - if getattr(p, "image_mask", None) is not None: - return True - - if getattr(p, "mask", None) is not None: - return True - - if getattr(p, "nmask", None) is not None: - return True - - return False - - -def latent_blend(settings, a, b, t): - """ - Interpolates two latent image representations according to the parameter t, - where the interpolated vectors' magnitudes are also interpolated separately. - The "detail_preservation" factor biases the magnitude interpolation towards - the larger of the two magnitudes. - """ - import torch - - # NOTE: We use inplace operations wherever possible. 
- - # [4][w][h] to [1][4][w][h] - t2 = t.unsqueeze(0) - # [4][w][h] to [1][1][w][h] - the [4] seem redundant. - t3 = t[0].unsqueeze(0).unsqueeze(0) - - one_minus_t2 = 1 - t2 - one_minus_t3 = 1 - t3 - - # Linearly interpolate the image vectors. - a_scaled = a * one_minus_t2 - b_scaled = b * t2 - image_interp = a_scaled - image_interp.add_(b_scaled) - result_type = image_interp.dtype - del a_scaled, b_scaled, t2, one_minus_t2 - - # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) - # 64-bit operations are used here to allow large exponents. - current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001) - - # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). - a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_( - settings.inpaint_detail_preservation) * one_minus_t3 - b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_( - settings.inpaint_detail_preservation) * t3 - desired_magnitude = a_magnitude - desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation) - del a_magnitude, b_magnitude, t3, one_minus_t3 - - # Change the linearly interpolated image vectors' magnitudes to the value we want. - # This is the last 64-bit operation. - image_interp_scaling_factor = desired_magnitude - image_interp_scaling_factor.div_(current_magnitude) - image_interp_scaling_factor = image_interp_scaling_factor.to(result_type) - image_interp_scaled = image_interp - image_interp_scaled.mul_(image_interp_scaling_factor) - del current_magnitude - del desired_magnitude - del image_interp - del image_interp_scaling_factor - del result_type - - return image_interp_scaled - - -def get_modified_nmask(settings, nmask, sigma): - """ - Converts a negative mask representing the transparency of the original latent vectors being overlayed - to a mask that is scaled according to the denoising strength for this step. - - Where: - 0 = fully opaque, infinite density, fully masked - 1 = fully transparent, zero density, fully unmasked - - We bring this transparency to a power, as this allows one to simulate N number of blending operations - where N can be any positive real value. Using this one can control the balance of influence between - the denoiser and the original latents according to the sigma value. - - NOTE: "mask" is not used - """ - import torch - return torch.pow(nmask, (sigma ** settings.mask_blend_power) * settings.mask_blend_scale) - - -def apply_adaptive_masks( - settings: SoftInpaintingSettings, - nmask, - latent_orig, - latent_processed, - overlay_images, - width, height, - paste_to): - import torch - import modules.processing as proc - import modules.images as images - from PIL import Image, ImageOps, ImageFilter - - # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. 
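A hedged usage sketch for latent_blend() and get_modified_nmask() above, assuming this module can be imported as soft_inpainting (the import path is an assumption) and that torch is available:

import torch
from soft_inpainting import SoftInpaintingSettings, latent_blend, get_modified_nmask

settings = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2)  # the module defaults
orig     = torch.randn(1, 4, 8, 8)     # original latents
denoised = torch.randn(1, 4, 8, 8)     # denoised latents for the current step
nmask    = torch.full((4, 8, 8), 0.5)  # soft inpainting mask in latent space
sigma    = 7.0                         # illustrative noise level for this step
blended  = latent_blend(settings, orig, denoised, get_modified_nmask(settings, nmask, sigma))
print(blended.shape)                   # torch.Size([1, 4, 8, 8])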
- latent_mask = nmask[0].float() - # convert the original mask into a form we use to scale distances for thresholding - mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2)) - mask_scalar = (0.5 * (1 - settings.composite_mask_influence) - + mask_scalar * settings.composite_mask_influence) - mask_scalar = mask_scalar / (1.00001 - mask_scalar) - mask_scalar = mask_scalar.cpu().numpy() - - latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) - - kernel, kernel_center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2) - - masks_for_overlay = [] - - for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): - converted_mask = distance_map.float().cpu().numpy() - converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.9, percentile_max=1, min_width=1) - converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, - percentile_min=0.25, percentile_max=0.75, min_width=1) - - # The distance at which opacity of original decreases to 50% - half_weighted_distance = settings.composite_difference_threshold * mask_scalar - converted_mask = converted_mask / half_weighted_distance - - converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast) - converted_mask = smootherstep(converted_mask) - converted_mask = 1 - converted_mask - converted_mask = 255. * converted_mask - converted_mask = converted_mask.astype(np.uint8) - converted_mask = Image.fromarray(converted_mask) - converted_mask = images.resize_image(2, converted_mask, width, height) - converted_mask = proc.create_binary_mask(converted_mask, round=False) - - # Remove aliasing artifacts using a gaussian blur. - converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) - - # Expand the mask to fit the whole image if needed. - if paste_to is not None: - converted_mask = proc.uncrop(converted_mask, - (overlay_image.width, overlay_image.height), - paste_to) - - masks_for_overlay.append(converted_mask) - - image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) - image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), - mask=ImageOps.invert(converted_mask.convert('L'))) - - overlay_images[i] = image_masked.convert('RGBA') - - return masks_for_overlay - - -def apply_masks( - settings, - nmask, - overlay_images, - width, height, - paste_to): - import torch - import modules.processing as proc - import modules.images as images - from PIL import Image, ImageOps, ImageFilter - - converted_mask = nmask[0].float() - converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(settings.mask_blend_scale / 2) - converted_mask = 255. * converted_mask - converted_mask = converted_mask.cpu().numpy().astype(np.uint8) - converted_mask = Image.fromarray(converted_mask) - converted_mask = images.resize_image(2, converted_mask, width, height) - converted_mask = proc.create_binary_mask(converted_mask, round=False) - - # Remove aliasing artifacts using a gaussian blur. - converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) - - # Expand the mask to fit the whole image if needed. 
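A worked example (not part of the file) of the difference-to-opacity mapping implemented in apply_adaptive_masks() above, using the module defaults (difference threshold 0.5, contrast 2) and taking mask_scalar as 1.0:

def original_opacity(difference, threshold=0.5, contrast=2.0, mask_scalar=1.0):
    half_weighted_distance = threshold * mask_scalar
    x = difference / half_weighted_distance
    x = 1.0 / (1.0 + x ** contrast)
    return x * x * x * (x * (6 * x - 15) + 10)  # smootherstep(), defined later in this file

for d in (0.1, 0.5, 1.0, 2.0):
    print(f"latent difference {d:.1f} -> opacity of original pixels {original_opacity(d):.2f}")

At a difference equal to the threshold the opacity is exactly 0.5, which is what the help text above describes.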
- if paste_to is not None: - converted_mask = proc.uncrop(converted_mask, - (width, height), - paste_to) - - masks_for_overlay = [] - - for i, overlay_image in enumerate(overlay_images): - masks_for_overlay[i] = converted_mask - - image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) - image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), - mask=ImageOps.invert(converted_mask.convert('L'))) - - overlay_images[i] = image_masked.convert('RGBA') - - return masks_for_overlay - - -def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0): - """ - Generalization convolution filter capable of applying - weighted mean, median, maximum, and minimum filters - parametrically using an arbitrary kernel. - - Args: - img (nparray): - The image, a 2-D array of floats, to which the filter is being applied. - kernel (nparray): - The kernel, a 2-D array of floats. - kernel_center (nparray): - The kernel center coordinate, a 1-D array with two elements. - percentile_min (float): - The lower bound of the histogram window used by the filter, - from 0 to 1. - percentile_max (float): - The upper bound of the histogram window used by the filter, - from 0 to 1. - min_width (float): - The minimum size of the histogram window bounds, in weight units. - Must be greater than 0. - - Returns: - (nparray): A filtered copy of the input image "img", a 2-D array of floats. - """ - - # Converts an index tuple into a vector. - def vec(x): - return np.array(x) - - kernel_min = -kernel_center - kernel_max = vec(kernel.shape) - kernel_center - - def weighted_histogram_filter_single(idx): - idx = vec(idx) - min_index = np.maximum(0, idx + kernel_min) - max_index = np.minimum(vec(img.shape), idx + kernel_max) - window_shape = max_index - min_index - - class WeightedElement: - """ - An element of the histogram, its weight - and bounds. - """ - - def __init__(self, value, weight): - self.value: float = value - self.weight: float = weight - self.window_min: float = 0.0 - self.window_max: float = 1.0 - - # Collect the values in the image as WeightedElements, - # weighted by their corresponding kernel values. - values = [] - for window_tup in np.ndindex(tuple(window_shape)): - window_index = vec(window_tup) - image_index = window_index + min_index - centered_kernel_index = image_index - idx - kernel_index = centered_kernel_index + kernel_center - element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)]) - values.append(element) - - def sort_key(x: WeightedElement): - return x.value - - values.sort(key=sort_key) - - # Calculate the height of the stack (sum) - # and each sample's range they occupy in the stack - sum = 0 - for i in range(len(values)): - values[i].window_min = sum - sum += values[i].weight - values[i].window_max = sum - - # Calculate what range of this stack ("window") - # we want to get the weighted average across. - window_min = sum * percentile_min - window_max = sum * percentile_max - window_width = window_max - window_min - - # Ensure the window is within the stack and at least a certain size. 
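-        # The percentile window is first widened to at least min_width around its own
-        # center, then shifted back inside [0, sum] if the widening pushed either end
-        # past the stack bounds; the weighted average below only counts the portion of
-        # each sample that overlaps this clamped window.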
- if window_width < min_width: - window_center = (window_min + window_max) / 2 - window_min = window_center - min_width / 2 - window_max = window_center + min_width / 2 - - if window_max > sum: - window_max = sum - window_min = sum - min_width - - if window_min < 0: - window_min = 0 - window_max = min_width - - value = 0 - value_weight = 0 - - # Get the weighted average of all the samples - # that overlap with the window, weighted - # by the size of their overlap. - for i in range(len(values)): - if window_min >= values[i].window_max: - continue - if window_max <= values[i].window_min: - break - - s = max(window_min, values[i].window_min) - e = min(window_max, values[i].window_max) - w = e - s - - value += values[i].value * w - value_weight += w - - return value / value_weight if value_weight != 0 else 0 - - img_out = img.copy() - - # Apply the kernel operation over each pixel. - for index in np.ndindex(img.shape): - img_out[index] = weighted_histogram_filter_single(index) - - return img_out - - -def smoothstep(x): - """ - The smoothstep function, input should be clamped to 0-1 range. - Turns a diagonal line (f(x) = x) into a sigmoid-like curve. - """ - return x * x * (3 - 2 * x) - - -def smootherstep(x): - """ - The smootherstep function, input should be clamped to 0-1 range. - Turns a diagonal line (f(x) = x) into a sigmoid-like curve. - """ - return x * x * x * (x * (6 * x - 15) + 10) - - -def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): - """ - Creates a Gaussian kernel with thresholded edges. - - Args: - stddev_radius (float): - Standard deviation of the gaussian kernel, in pixels. - max_radius (int): - The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2. - The kernel is thresholded so that any values one pixel beyond this radius - is weighted at 0. - - Returns: - (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2)) - """ - - # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean. - def gaussian(sqr_mag): - return math.exp(-sqr_mag / (stddev_radius * stddev_radius)) - - # Helper function for converting a tuple to an array. - def vec(x): - return np.array(x) - - """ - Since a gaussian is unbounded, we need to limit ourselves - to a finite range. - We taper the ends off at the end of that range so they equal zero - while preserving the maximum value of 1 at the mean. 
- """ - zero_radius = max_radius + 1.0 - gauss_zero = gaussian(zero_radius * zero_radius) - gauss_kernel_scale = 1 / (1 - gauss_zero) - - def gaussian_kernel_func(coordinate): - x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0 - x = gaussian(x) - x -= gauss_zero - x *= gauss_kernel_scale - x = max(0.0, x) - return x - - size = max_radius * 2 + 1 - kernel_center = max_radius - kernel = np.zeros((size, size)) - - for index in np.ndindex(kernel.shape): - kernel[index] = gaussian_kernel_func(vec(index) - kernel_center) - - return kernel, kernel_center - - -# ------------------- Constants ------------------- - - -default = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2) - -enabled_ui_label = "Soft inpainting" -enabled_gen_param_label = "Soft inpainting enabled" -enabled_el_id = "soft_inpainting_enabled" - -ui_labels = SoftInpaintingSettings( - "Schedule bias", - "Preservation strength", - "Transition contrast boost", - "Mask influence", - "Difference threshold", - "Difference contrast") - -ui_info = SoftInpaintingSettings( - "Shifts when preservation of original content occurs during denoising.", - "How strongly partially masked content should be preserved.", - "Amplifies the contrast that may be lost in partially masked regions.", - "How strongly the original mask should bias the difference threshold.", - "How much an image region can change before the original pixels are not blended in anymore.", - "How sharp the transition should be between blended and not blended.") - -gen_param_labels = SoftInpaintingSettings( - "Soft inpainting schedule bias", - "Soft inpainting preservation strength", - "Soft inpainting transition contrast boost", - "Soft inpainting mask influence", - "Soft inpainting difference threshold", - "Soft inpainting difference contrast") - -el_ids = SoftInpaintingSettings( - "mask_blend_power", - "mask_blend_scale", - "inpaint_detail_preservation", - "composite_mask_influence", - "composite_difference_threshold", - "composite_difference_contrast") - - -# ------------------- Script ------------------- - - -class Script(scripts.Script): - def __init__(self): - self.section = "inpaint" - self.masks_for_overlay = None - self.overlay_images = None - - def title(self): - return "Soft Inpainting" - - def show(self, is_img2img): - return scripts.AlwaysVisible if is_img2img else False - - def ui(self, is_img2img): - if not is_img2img: - return - - with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled: - with gr.Group(): - gr.Markdown( - """ - Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity. - **High _Mask blur_** values are recommended! 
- """) - - power = \ - gr.Slider(label=ui_labels.mask_blend_power, - info=ui_info.mask_blend_power, - minimum=0, - maximum=8, - step=0.1, - value=default.mask_blend_power, - elem_id=el_ids.mask_blend_power) - scale = \ - gr.Slider(label=ui_labels.mask_blend_scale, - info=ui_info.mask_blend_scale, - minimum=0, - maximum=8, - step=0.05, - value=default.mask_blend_scale, - elem_id=el_ids.mask_blend_scale) - detail = \ - gr.Slider(label=ui_labels.inpaint_detail_preservation, - info=ui_info.inpaint_detail_preservation, - minimum=1, - maximum=32, - step=0.5, - value=default.inpaint_detail_preservation, - elem_id=el_ids.inpaint_detail_preservation) - - gr.Markdown( - """ - ### Pixel Composite Settings - """) - - mask_inf = \ - gr.Slider(label=ui_labels.composite_mask_influence, - info=ui_info.composite_mask_influence, - minimum=0, - maximum=1, - step=0.05, - value=default.composite_mask_influence, - elem_id=el_ids.composite_mask_influence) - - dif_thresh = \ - gr.Slider(label=ui_labels.composite_difference_threshold, - info=ui_info.composite_difference_threshold, - minimum=0, - maximum=8, - step=0.25, - value=default.composite_difference_threshold, - elem_id=el_ids.composite_difference_threshold) - - dif_contr = \ - gr.Slider(label=ui_labels.composite_difference_contrast, - info=ui_info.composite_difference_contrast, - minimum=0, - maximum=8, - step=0.25, - value=default.composite_difference_contrast, - elem_id=el_ids.composite_difference_contrast) - - with gr.Accordion("Help", open=False): - gr.Markdown( - f""" - ### {ui_labels.mask_blend_power} - - The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas). - This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step. - This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation. - - - **Below 1**: Stronger preservation near the end (with low sigma) - - **1**: Balanced (proportional to sigma) - - **Above 1**: Stronger preservation in the beginning (with high sigma) - """) - gr.Markdown( - f""" - ### {ui_labels.mask_blend_scale} - - Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content. - This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength. - - - **Low values**: Favors generated content. - - **High values**: Favors original content. - """) - gr.Markdown( - f""" - ### {ui_labels.inpaint_detail_preservation} - - This parameter controls how the original latent vectors and denoised latent vectors are interpolated. - With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors. - This can prevent the loss of contrast that occurs with linear interpolation. - - - **Low values**: Softer blending, details may fade. - - **High values**: Stronger contrast, may over-saturate colors. - """) - - gr.Markdown( - """ - ## Pixel Composite Settings - - Masks are generated based on how much a part of the image changed after denoising. - These masks are used to blend the original and final images together. - If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process. - """) - - gr.Markdown( - f""" - ### {ui_labels.composite_mask_influence} - - This parameter controls how much the mask should bias this sensitivity to difference. 
- - - **0**: Ignore the mask, only consider differences in image content. - - **1**: Follow the mask closely despite image content changes. - """) - - gr.Markdown( - f""" - ### {ui_labels.composite_difference_threshold} - - This value represents the difference at which the original pixels will have less than 50% opacity. - - - **Low values**: Two images patches must be almost the same in order to retain original pixels. - - **High values**: Two images patches can be very different and still retain original pixels. - """) - - gr.Markdown( - f""" - ### {ui_labels.composite_difference_contrast} - - This value represents the contrast between the opacity of the original and inpainted content. - - - **Low values**: The blend will be more gradual and have longer transitions, but may cause ghosting. - - **High values**: Ghosting will be less common, but transitions may be very sudden. - """) - - self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label), - (power, gen_param_labels.mask_blend_power), - (scale, gen_param_labels.mask_blend_scale), - (detail, gen_param_labels.inpaint_detail_preservation), - (mask_inf, gen_param_labels.composite_mask_influence), - (dif_thresh, gen_param_labels.composite_difference_threshold), - (dif_contr, gen_param_labels.composite_difference_contrast)] - - self.paste_field_names = [] - for _, field_name in self.infotext_fields: - self.paste_field_names.append(field_name) - - return [soft_inpainting_enabled, - power, - scale, - detail, - mask_inf, - dif_thresh, - dif_contr] - - def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): - if not enabled: - return - - if not processing_uses_inpainting(p): - return - - # Shut off the rounding it normally does. - p.mask_round = False - - settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) - - # p.extra_generation_params["Mask rounding"] = False - settings.add_generation_params(p.extra_generation_params) - - def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf, - dif_thresh, dif_contr): - if not enabled: - return - - if not processing_uses_inpainting(p): - return - - if mba.is_final_blend: - mba.blended_latent = mba.current_latent - return - - settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) - - # todo: Why is sigma 2D? Both values are the same. - mba.blended_latent = latent_blend(settings, - mba.init_latent, - mba.current_latent, - get_modified_nmask(settings, mba.nmask, mba.sigma[0])) - - def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf, - dif_thresh, dif_contr): - if not enabled: - return - - if not processing_uses_inpainting(p): - return - - nmask = getattr(p, "nmask", None) - if nmask is None: - return - - from modules import images - from modules.shared import opts - - settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) - - # since the original code puts holes in the existing overlay images, - # we have to rebuild them. 
- self.overlay_images = [] - for img in p.init_images: - - image = images.flatten(img, opts.img2img_background_color) - - if p.paste_to is None and p.resize_mode != 3: - image = images.resize_image(p.resize_mode, image, p.width, p.height) - - self.overlay_images.append(image.convert('RGBA')) - - if len(p.init_images) == 1: - self.overlay_images = self.overlay_images * p.batch_size - - if getattr(ps.samples, 'already_decoded', False): - self.masks_for_overlay = apply_masks(settings=settings, - nmask=nmask, - overlay_images=self.overlay_images, - width=p.width, - height=p.height, - paste_to=p.paste_to) - else: - self.masks_for_overlay = apply_adaptive_masks(settings=settings, - nmask=nmask, - latent_orig=p.init_latent, - latent_processed=ps.samples, - overlay_images=self.overlay_images, - width=p.width, - height=p.height, - paste_to=p.paste_to) - - def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, - detail_preservation, mask_inf, dif_thresh, dif_contr): - if not enabled: - return - - if not processing_uses_inpainting(p): - return - - if self.masks_for_overlay is None: - return - - if self.overlay_images is None: - return - - ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index] - ppmo.overlay_image = self.overlay_images[ppmo.index] -- cgit v1.2.3 From 59d060fd5ea93fcc3fdbfbd13b6e20fda06ecf94 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 30 Dec 2023 17:11:03 +0900 Subject: More lora not found warning --- extensions-builtin/Lora/networks.py | 8 +++++++- extensions-builtin/Lora/scripts/lora_script.py | 2 ++ 2 files changed, 9 insertions(+), 1 deletion(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 985b2753..72ebd624 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -1,3 +1,4 @@ +import gradio as gr import logging import os import re @@ -314,7 +315,12 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No emb_db.skipped_embeddings[name] = embedding if failed_to_load_networks: - sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks)) + lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}' + sd_hijack.model_hijack.comments.append(lora_not_found_message) + if shared.opts.lora_not_found_warning_console: + print(f'\n{lora_not_found_message}\n') + if shared.opts.lora_not_found_gradio_warning: + gr.Warning(lora_not_found_message) purge_networks_from_memory() diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index ef23968c..1518f7e5 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -39,6 +39,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}), + "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"), + 
"lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"), })) -- cgit v1.2.3 From b0f59342346b1c8b405f97c0e0bb01c6ae05c601 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 25 Dec 2023 14:43:51 +0200 Subject: Use Spandrel for upscaling and face restoration architectures (aside from GFPGAN and LDSR) --- extensions-builtin/ScuNET/scripts/scunet_model.py | 13 +- extensions-builtin/ScuNET/scunet_model_arch.py | 268 ------ extensions-builtin/SwinIR/scripts/swinir_model.py | 126 ++- extensions-builtin/SwinIR/swinir_model_arch.py | 867 ------------------ extensions-builtin/SwinIR/swinir_model_arch_v2.py | 1017 --------------------- modules/codeformer/codeformer_arch.py | 276 ------ modules/codeformer/vqgan_arch.py | 435 --------- modules/codeformer_model.py | 195 ++-- modules/esrgan_model.py | 153 +--- modules/esrgan_model_arch.py | 465 ---------- modules/gfpgan_model.py | 13 +- modules/launch_utils.py | 7 - modules/modelloader.py | 16 + modules/paths.py | 1 - modules/realesrgan_model.py | 153 ++-- modules/sysinfo.py | 2 - modules/upscaler.py | 3 + requirements.txt | 3 +- requirements_versions.txt | 4 +- 19 files changed, 263 insertions(+), 3754 deletions(-) delete mode 100644 extensions-builtin/ScuNET/scunet_model_arch.py delete mode 100644 extensions-builtin/SwinIR/swinir_model_arch.py delete mode 100644 extensions-builtin/SwinIR/swinir_model_arch_v2.py delete mode 100644 modules/codeformer/codeformer_arch.py delete mode 100644 modules/codeformer/vqgan_arch.py delete mode 100644 modules/esrgan_model_arch.py (limited to 'extensions-builtin') diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 167d2f64..18cf8e1a 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -7,9 +7,7 @@ from tqdm import tqdm import modules.upscaler from modules import devices, modelloader, script_callbacks, errors -from scunet_model_arch import SCUNet -from modules.modelloader import load_file_from_url from modules.shared import opts @@ -120,17 +118,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler): device = devices.get_device_for('scunet') if path.startswith("http"): # TODO: this doesn't use `path` at all? 
- filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") + filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path - model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) - model.load_state_dict(torch.load(filename), strict=True) - model.eval() - for _, v in model.named_parameters(): - v.requires_grad = False - model = model.to(device) - - return model + return modelloader.load_spandrel_model(filename, device=device) def on_ui_settings(): diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py deleted file mode 100644 index b51a8806..00000000 --- a/extensions-builtin/ScuNET/scunet_model_arch.py +++ /dev/null @@ -1,268 +0,0 @@ -# -*- coding: utf-8 -*- -import numpy as np -import torch -import torch.nn as nn -from einops import rearrange -from einops.layers.torch import Rearrange -from timm.models.layers import trunc_normal_, DropPath - - -class WMSA(nn.Module): - """ Self-attention module in Swin Transformer - """ - - def __init__(self, input_dim, output_dim, head_dim, window_size, type): - super(WMSA, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - self.head_dim = head_dim - self.scale = self.head_dim ** -0.5 - self.n_heads = input_dim // head_dim - self.window_size = window_size - self.type = type - self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True) - - self.relative_position_params = nn.Parameter( - torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)) - - self.linear = nn.Linear(self.input_dim, self.output_dim) - - trunc_normal_(self.relative_position_params, std=.02) - self.relative_position_params = torch.nn.Parameter( - self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1, - 2).transpose( - 0, 1)) - - def generate_mask(self, h, w, p, shift): - """ generating the mask of SW-MSA - Args: - shift: shift parameters in CyclicShift. - Returns: - attn_mask: should be (1 1 w p p), - """ - # supporting square. - attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device) - if self.type == 'W': - return attn_mask - - s = p - shift - attn_mask[-1, :, :s, :, s:, :] = True - attn_mask[-1, :, s:, :, :s, :] = True - attn_mask[:, -1, :, :s, :, s:] = True - attn_mask[:, -1, :, s:, :, :s] = True - attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)') - return attn_mask - - def forward(self, x): - """ Forward pass of Window Multi-head Self-attention module. 
- Args: - x: input tensor with shape of [b h w c]; - attn_mask: attention mask, fill -inf where the value is True; - Returns: - output: tensor shape [b h w c] - """ - if self.type != 'W': - x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) - - x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) - h_windows = x.size(1) - w_windows = x.size(2) - # square validation - # assert h_windows == w_windows - - x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size) - qkv = self.embedding_layer(x) - q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0) - sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale - # Adding learnable relative embedding - sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q') - # Using Attn Mask to distinguish different subwindows. - if self.type != 'W': - attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2) - sim = sim.masked_fill_(attn_mask, float("-inf")) - - probs = nn.functional.softmax(sim, dim=-1) - output = torch.einsum('hbwij,hbwjc->hbwic', probs, v) - output = rearrange(output, 'h b w p c -> b w p (h c)') - output = self.linear(output) - output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size) - - if self.type != 'W': - output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2)) - - return output - - def relative_embedding(self): - cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)])) - relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1 - # negative is allowed - return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()] - - -class Block(nn.Module): - def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): - """ SwinTransformer Block - """ - super(Block, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - assert type in ['W', 'SW'] - self.type = type - if input_resolution <= window_size: - self.type = 'W' - - self.ln1 = nn.LayerNorm(input_dim) - self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type) - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.ln2 = nn.LayerNorm(input_dim) - self.mlp = nn.Sequential( - nn.Linear(input_dim, 4 * input_dim), - nn.GELU(), - nn.Linear(4 * input_dim, output_dim), - ) - - def forward(self, x): - x = x + self.drop_path(self.msa(self.ln1(x))) - x = x + self.drop_path(self.mlp(self.ln2(x))) - return x - - -class ConvTransBlock(nn.Module): - def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): - """ SwinTransformer and Conv Block - """ - super(ConvTransBlock, self).__init__() - self.conv_dim = conv_dim - self.trans_dim = trans_dim - self.head_dim = head_dim - self.window_size = window_size - self.drop_path = drop_path - self.type = type - self.input_resolution = input_resolution - - assert self.type in ['W', 'SW'] - if self.input_resolution <= self.window_size: - self.type = 'W' - - self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path, - self.type, self.input_resolution) - self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) - self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) - - self.conv_block = nn.Sequential( - nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False), - nn.ReLU(True), - nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False) - ) - - def forward(self, x): - conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1) - conv_x = self.conv_block(conv_x) + conv_x - trans_x = Rearrange('b c h w -> b h w c')(trans_x) - trans_x = self.trans_block(trans_x) - trans_x = Rearrange('b h w c -> b c h w')(trans_x) - res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1)) - x = x + res - - return x - - -class SCUNet(nn.Module): - # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256): - def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256): - super(SCUNet, self).__init__() - if config is None: - config = [2, 2, 2, 2, 2, 2, 2] - self.config = config - self.dim = dim - self.head_dim = 32 - self.window_size = 8 - - # drop path rate for each layer - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))] - - self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)] - - begin = 0 - self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution) - for i in range(config[0])] + \ - [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)] - - begin += config[0] - self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 2) - for i in range(config[1])] + \ - [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)] - - begin += config[1] - self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 4) - for i in range(config[2])] + \ - [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)] - - begin += config[2] - self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 8) - for i in range(config[3])] - - begin += config[3] - self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 
else 'SW', input_resolution // 4) - for i in range(config[4])] - - begin += config[4] - self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 2) - for i in range(config[5])] - - begin += config[5] - self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution) - for i in range(config[6])] - - self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)] - - self.m_head = nn.Sequential(*self.m_head) - self.m_down1 = nn.Sequential(*self.m_down1) - self.m_down2 = nn.Sequential(*self.m_down2) - self.m_down3 = nn.Sequential(*self.m_down3) - self.m_body = nn.Sequential(*self.m_body) - self.m_up3 = nn.Sequential(*self.m_up3) - self.m_up2 = nn.Sequential(*self.m_up2) - self.m_up1 = nn.Sequential(*self.m_up1) - self.m_tail = nn.Sequential(*self.m_tail) - # self.apply(self._init_weights) - - def forward(self, x0): - - h, w = x0.size()[-2:] - paddingBottom = int(np.ceil(h / 64) * 64 - h) - paddingRight = int(np.ceil(w / 64) * 64 - w) - x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0) - - x1 = self.m_head(x0) - x2 = self.m_down1(x1) - x3 = self.m_down2(x2) - x4 = self.m_down3(x3) - x = self.m_body(x4) - x = self.m_up3(x + x4) - x = self.m_up2(x + x3) - x = self.m_up1(x + x2) - x = self.m_tail(x + x1) - - x = x[..., :h, :w] - - return x - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index ae0d0e6a..85c18b9e 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,5 +1,5 @@ +import logging import sys -import platform import numpy as np import torch @@ -8,13 +8,11 @@ from tqdm import tqdm from modules import modelloader, devices, script_callbacks, shared from modules.shared import opts, state -from swinir_model_arch import SwinIR -from swinir_model_arch_v2 import Swin2SR from modules.upscaler import Upscaler, UpscalerData SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" -device_swinir = devices.get_device_for('swinir') +logger = logging.getLogger(__name__) class UpscalerSwinIR(Upscaler): @@ -37,26 +35,29 @@ class UpscalerSwinIR(Upscaler): scalers.append(model_data) self.scalers = scalers - def do_upscale(self, img, model_file): - use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \ - and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows" + def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image: current_config = (model_file, opts.SWIN_tile) - if use_compile and self._cached_model_config == current_config: + device = self._get_device() + + if self._cached_model_config == current_config: model = self._cached_model else: - self._cached_model = None try: model = self.load_model(model_file) except Exception as e: print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr) return img - model = model.to(device_swinir, dtype=devices.dtype) - if 
use_compile: - model = torch.compile(model) - self._cached_model = model - self._cached_model_config = current_config - img = upscale(img, model) + self._cached_model = model + self._cached_model_config = current_config + + img = upscale( + img, + model, + tile=opts.SWIN_tile, + tile_overlap=opts.SWIN_tile_overlap, + device=device, + ) devices.torch_gc() return img @@ -69,69 +70,54 @@ class UpscalerSwinIR(Upscaler): ) else: filename = path - if filename.endswith(".v2.pth"): - model = Swin2SR( - upscale=scale, - in_chans=3, - img_size=64, - window_size=8, - img_range=1.0, - depths=[6, 6, 6, 6, 6, 6], - embed_dim=180, - num_heads=[6, 6, 6, 6, 6, 6], - mlp_ratio=2, - upsampler="nearest+conv", - resi_connection="1conv", - ) - params = None - else: - model = SwinIR( - upscale=scale, - in_chans=3, - img_size=64, - window_size=8, - img_range=1.0, - depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], - embed_dim=240, - num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], - mlp_ratio=2, - upsampler="nearest+conv", - resi_connection="3conv", - ) - params = "params_ema" - pretrained_model = torch.load(filename) - if params is not None: - model.load_state_dict(pretrained_model[params], strict=True) - else: - model.load_state_dict(pretrained_model, strict=True) + model = modelloader.load_spandrel_model( + filename, + device=self._get_device(), + dtype=devices.dtype, + ) + if getattr(opts, 'SWIN_torch_compile', False): + try: + model = torch.compile(model) + except Exception: + logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True) return model + def _get_device(self): + return devices.get_device_for('swinir') + def upscale( - img, - model, - tile=None, - tile_overlap=None, - window_size=8, - scale=4, + img, + model, + *, + tile: int, + tile_overlap: int, + window_size=8, + scale=4, + device, ): - tile = tile or opts.SWIN_tile - tile_overlap = tile_overlap or opts.SWIN_tile_overlap - img = np.array(img) img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype) + img = img.unsqueeze(0).to(device, dtype=devices.dtype) with torch.no_grad(), devices.autocast(): _, _, h_old, w_old = img.size() h_pad = (h_old // window_size + 1) * window_size - h_old w_pad = (w_old // window_size + 1) * window_size - w_old img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] - output = inference(img, model, tile, tile_overlap, window_size, scale) + output = inference( + img, + model, + tile=tile, + tile_overlap=tile_overlap, + window_size=window_size, + scale=scale, + device=device, + ) output = output[..., : h_old * scale, : w_old * scale] output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() if output.ndim == 3: @@ -142,7 +128,16 @@ def upscale( return Image.fromarray(output, "RGB") -def inference(img, model, tile, tile_overlap, window_size, scale): +def inference( + img, + model, + *, + tile: int, + tile_overlap: int, + window_size: int, + scale: int, + device, +): # test the image tile by tile b, c, h, w = img.size() tile = min(tile, h, w) @@ -152,8 +147,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale): stride = tile - tile_overlap h_idx_list = list(range(0, h - tile, stride)) + [h - tile] w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img) - W = torch.zeros_like(E, dtype=devices.dtype, 
device=device_swinir) + E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img) + W = torch.zeros_like(E, dtype=devices.dtype, device=device) with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: for h_idx in h_idx_list: @@ -185,8 +180,7 @@ def on_ui_settings(): shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling"))) shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling"))) - if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows - shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run")) + shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run")) script_callbacks.on_ui_settings(on_ui_settings) diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py deleted file mode 100644 index 93b93274..00000000 --- a/extensions-builtin/SwinIR/swinir_model_arch.py +++ /dev/null @@ -1,867 +0,0 @@ -# ----------------------------------------------------------------------------------- -# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 -# Originally Written by Ze Liu, Modified by Jingyun Liang. 
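A worked example (not part of the patch) of the padding and tile-grid arithmetic used by upscale() and inference() above, with an illustrative 640x360 input and the default SWIN_tile settings:

window_size, scale = 8, 4
h_old, w_old = 360, 640                                     # illustrative input size
h_pad = (h_old // window_size + 1) * window_size - h_old    # 8 -> padded height 368
w_pad = (w_old // window_size + 1) * window_size - w_old    # 8 -> padded width 648
tile, tile_overlap = 192, 8                                 # SWIN_tile defaults above
stride = tile - tile_overlap                                # 184
h_idx_list = list(range(0, (h_old + h_pad) - tile, stride)) + [(h_old + h_pad) - tile]
w_idx_list = list(range(0, (w_old + w_pad) - tile, stride)) + [(w_old + w_pad) - tile]
print(len(h_idx_list) * len(w_idx_list), "tiles")           # 2 * 4 = 8 overlapping 192x192 tiles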
-# ----------------------------------------------------------------------------------- - -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. 
- num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - attn_mask = self.calculate_mask(self.input_resolution) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def calculate_mask(self, x_size): - # calculate attention mask for SW-MSA - H, W = x_size - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - return attn_mask - - def forward(self, x, x_size): - H, W = x_size - B, L, C = x.shape - # assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of 
window size - if self.input_resolution == x_size: - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C - else: - attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - x = x.view(B, H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, x_size): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x, x_size) - else: - x = blk(x, x_size) - if self.downsample is not None: - x = self.downsample(x) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - -class RSTB(nn.Module): - """Residual Swin Transformer Block (RSTB). - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - img_size: Input image size. - patch_size: Patch size. - resi_connection: The convolutional block before residual connection. 
- """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - img_size=224, patch_size=4, resi_connection='1conv'): - super(RSTB, self).__init__() - - self.dim = dim - self.input_resolution = input_resolution - - self.residual_group = BasicLayer(dim=dim, - input_resolution=input_resolution, - depth=depth, - num_heads=num_heads, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path, - norm_layer=norm_layer, - downsample=downsample, - use_checkpoint=use_checkpoint) - - if resi_connection == '1conv': - self.conv = nn.Conv2d(dim, dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim, 3, 1, 1)) - - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, - norm_layer=None) - - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, - norm_layer=None) - - def forward(self, x, x_size): - return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x - - def flops(self): - flops = 0 - flops += self.residual_group.flops() - H, W = self.input_resolution - flops += H * W * self.dim * self.dim * 9 - flops += self.patch_embed.flops() - flops += self.patch_unembed.flops() - - return flops - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - x = x.flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - flops = 0 - H, W = self.img_size - if self.norm is not None: - flops += H * W * self.embed_dim - return flops - - -class PatchUnEmbed(nn.Module): - r""" Image to Patch Unembedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. 
Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - def forward(self, x, x_size): - B, HW, C = x.shape - x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C - return x - - def flops(self): - flops = 0 - return flops - - -class Upsample(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample, self).__init__(*m) - - -class UpsampleOneStep(nn.Sequential): - """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) - Used in lightweight SR to save parameters. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - - """ - - def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): - self.num_feat = num_feat - self.input_resolution = input_resolution - m = [] - m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) - m.append(nn.PixelShuffle(scale)) - super(UpsampleOneStep, self).__init__(*m) - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.num_feat * 3 * 9 - return flops - - -class SwinIR(nn.Module): - r""" SwinIR - A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. - - Args: - img_size (int | tuple(int)): Input image size. Default 64 - patch_size (int | tuple(int)): Patch size. Default: 1 - in_chans (int): Number of input image channels. Default: 3 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False - upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction - img_range: Image range. 1. 
or 255. - upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None - resi_connection: The convolutional block before residual connection. '1conv'/'3conv' - """ - - def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', - **kwargs): - super(SwinIR, self).__init__() - num_in_ch = in_chans - num_out_ch = in_chans - num_feat = 64 - self.img_range = img_range - if in_chans == 3: - rgb_mean = (0.4488, 0.4371, 0.4040) - self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) - else: - self.mean = torch.zeros(1, 1, 1, 1) - self.upscale = upscale - self.upsampler = upsampler - self.window_size = window_size - - ##################################################################################################### - ################################### 1, shallow feature extraction ################################### - self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) - - ##################################################################################################### - ################################### 2, deep feature extraction ###################################### - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.num_features = embed_dim - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # merge non-overlapping patches into image - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build Residual Swin Transformer blocks (RSTB) - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers.append(layer) - self.norm = norm_layer(self.num_features) - - # build the last conv layer in deep feature extraction - if resi_connection == '1conv': - self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - 
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) - - ##################################################################################################### - ################################ 3, high quality image reconstruction ################################ - if self.upsampler == 'pixelshuffle': - # for classical SR - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR (to save parameters) - self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, - (patches_resolution[0], patches_resolution[1])) - elif self.upsampler == 'nearest+conv': - # for real-world SR (less artifacts) - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - if self.upscale == 4: - self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - else: - # for image denoising and JPEG compression artifact reduction - self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def check_image_size(self, x): - _, _, h, w = x.size() - mod_pad_h = (self.window_size - h % self.window_size) % self.window_size - mod_pad_w = (self.window_size - w % self.window_size) % self.window_size - x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') - return x - - def forward_features(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward(self, x): - H, W = x.shape[2:] - x = self.check_image_size(x) - - self.mean = self.mean.type_as(x) - x = (x - self.mean) * self.img_range - - if self.upsampler == 'pixelshuffle': - # for classical SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.conv_last(self.upsample(x)) - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.upsample(x) - elif self.upsampler == 'nearest+conv': - # for real-world SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - if 
self.upscale == 4: - x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.conv_last(self.lrelu(self.conv_hr(x))) - else: - # for image denoising and JPEG compression artifact reduction - x_first = self.conv_first(x) - res = self.conv_after_body(self.forward_features(x_first)) + x_first - x = x + self.conv_last(res) - - x = x / self.img_range + self.mean - - return x[:, :, :H*self.upscale, :W*self.upscale] - - def flops(self): - flops = 0 - H, W = self.patches_resolution - flops += H * W * 3 * self.embed_dim * 9 - flops += self.patch_embed.flops() - for layer in self.layers: - flops += layer.flops() - flops += H * W * 3 * self.embed_dim * self.embed_dim - flops += self.upsample.flops() - return flops - - -if __name__ == '__main__': - upscale = 4 - window_size = 8 - height = (1024 // upscale // window_size + 1) * window_size - width = (720 // upscale // window_size + 1) * window_size - model = SwinIR(upscale=2, img_size=(height, width), - window_size=window_size, img_range=1., depths=[6, 6, 6, 6], - embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') - print(model) - print(height, width, model.flops() / 1e9) - - x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py deleted file mode 100644 index dad22cca..00000000 --- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py +++ /dev/null @@ -1,1017 +0,0 @@ -# ----------------------------------------------------------------------------------- -# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/ -# Written by Conde and Choi et al. 
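Both deleted SR architectures reflect-pad the input so that height and width become multiples of the attention window before windows are partitioned (see check_image_size above). A minimal sketch of that padding arithmetic, assuming window_size=8 and an arbitrary input size (not part of the patch itself):

import torch
import torch.nn.functional as F

window_size = 8                            # assumed value for this sketch
x = torch.randn(1, 3, 70, 45)              # arbitrary example input
_, _, h, w = x.size()
mod_pad_h = (window_size - h % window_size) % window_size
mod_pad_w = (window_size - w % window_size) % window_size
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
print(x.shape)  # torch.Size([1, 3, 72, 48]) -- both dims now divisible by window_size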
-# ----------------------------------------------------------------------------------- - -import math -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - pretrained_window_size (tuple[int]): The height and width of the window in pre-training. 
- """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., - pretrained_window_size=(0, 0)): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.pretrained_window_size = pretrained_window_size - self.num_heads = num_heads - - self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) - - # mlp to generate continuous relative position bias - self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), - nn.ReLU(inplace=True), - nn.Linear(512, num_heads, bias=False)) - - # get relative_coords_table - relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) - relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) - relative_coords_table = torch.stack( - torch.meshgrid([relative_coords_h, - relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 - if pretrained_window_size[0] > 0: - relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) - relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) - else: - relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) - relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) - relative_coords_table *= 8 # normalize to -8, 8 - relative_coords_table = torch.sign(relative_coords_table) * torch.log2( - torch.abs(relative_coords_table) + 1.0) / np.log2(8) - - self.register_buffer("relative_coords_table", relative_coords_table) - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=False) - if qkv_bias: - self.q_bias = nn.Parameter(torch.zeros(dim)) - self.v_bias = nn.Parameter(torch.zeros(dim)) - else: - self.q_bias = None - self.v_bias = None - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv_bias = None - if self.q_bias is not None: - qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) - qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) - qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - # cosine attention - attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) - logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. 
/ 0.01)).to(self.logit_scale.device)).exp() - attn = attn * logit_scale - - relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) - relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - relative_position_bias = 16 * torch.sigmoid(relative_position_bias) - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, ' \ - f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - pretrained_window_size (int): Window size in pre-training. - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, - pretrained_window_size=to_2tuple(pretrained_window_size)) - - self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - attn_mask = self.calculate_mask(self.input_resolution) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def calculate_mask(self, x_size): - # calculate attention mask for SW-MSA - H, W = x_size - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - return attn_mask - - def forward(self, x, x_size): - H, W = x_size - B, L, C = x.shape - #assert L == H * W, "input feature has wrong size" - - shortcut = x - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size - if self.input_resolution == x_size: - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C - else: - attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - x = x.view(B, H * W, C) - x = shortcut + self.drop_path(self.norm1(x)) - - # FFN - x = x + self.drop_path(self.norm2(self.mlp(x))) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(2 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.reduction(x) - x = self.norm(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - flops += H * W * self.dim // 2 - return flops - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - pretrained_window_size (int): Local window size in pre-training. 
- """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - pretrained_window_size=0): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer, - pretrained_window_size=pretrained_window_size) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, x_size): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x, x_size) - else: - x = blk(x, x_size) - if self.downsample is not None: - x = self.downsample(x) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - def _init_respostnorm(self): - for blk in self.blocks: - nn.init.constant_(blk.norm1.bias, 0) - nn.init.constant_(blk.norm1.weight, 0) - nn.init.constant_(blk.norm2.bias, 0) - nn.init.constant_(blk.norm2.weight, 0) - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - # assert H == self.img_size[0] and W == self.img_size[1], - # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - Ho, Wo = self.patches_resolution - flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) - if self.norm is not None: - flops += Ho * Wo * self.embed_dim - return flops - -class RSTB(nn.Module): - """Residual Swin Transformer Block (RSTB). - - Args: - dim (int): Number of input channels. 
- input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - img_size: Input image size. - patch_size: Patch size. - resi_connection: The convolutional block before residual connection. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - img_size=224, patch_size=4, resi_connection='1conv'): - super(RSTB, self).__init__() - - self.dim = dim - self.input_resolution = input_resolution - - self.residual_group = BasicLayer(dim=dim, - input_resolution=input_resolution, - depth=depth, - num_heads=num_heads, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path, - norm_layer=norm_layer, - downsample=downsample, - use_checkpoint=use_checkpoint) - - if resi_connection == '1conv': - self.conv = nn.Conv2d(dim, dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim, 3, 1, 1)) - - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, - norm_layer=None) - - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, - norm_layer=None) - - def forward(self, x, x_size): - return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x - - def flops(self): - flops = 0 - flops += self.residual_group.flops() - H, W = self.input_resolution - flops += H * W * self.dim * self.dim * 9 - flops += self.patch_embed.flops() - flops += self.patch_unembed.flops() - - return flops - -class PatchUnEmbed(nn.Module): - r""" Image to Patch Unembedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. 
Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - def forward(self, x, x_size): - B, HW, C = x.shape - x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C - return x - - def flops(self): - flops = 0 - return flops - - -class Upsample(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample, self).__init__(*m) - -class Upsample_hf(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample_hf, self).__init__(*m) - - -class UpsampleOneStep(nn.Sequential): - """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) - Used in lightweight SR to save parameters. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - - """ - - def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): - self.num_feat = num_feat - self.input_resolution = input_resolution - m = [] - m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) - m.append(nn.PixelShuffle(scale)) - super(UpsampleOneStep, self).__init__(*m) - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.num_feat * 3 * 9 - return flops - - - -class Swin2SR(nn.Module): - r""" Swin2SR - A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`. - - Args: - img_size (int | tuple(int)): Input image size. Default 64 - patch_size (int | tuple(int)): Patch size. Default: 1 - in_chans (int): Number of input image channels. Default: 3 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - drop_rate (float): Dropout rate. 
Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False - upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction - img_range: Image range. 1. or 255. - upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None - resi_connection: The convolutional block before residual connection. '1conv'/'3conv' - """ - - def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), - window_size=7, mlp_ratio=4., qkv_bias=True, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', - **kwargs): - super(Swin2SR, self).__init__() - num_in_ch = in_chans - num_out_ch = in_chans - num_feat = 64 - self.img_range = img_range - if in_chans == 3: - rgb_mean = (0.4488, 0.4371, 0.4040) - self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) - else: - self.mean = torch.zeros(1, 1, 1, 1) - self.upscale = upscale - self.upsampler = upsampler - self.window_size = window_size - - ##################################################################################################### - ################################### 1, shallow feature extraction ################################### - self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) - - ##################################################################################################### - ################################### 2, deep feature extraction ###################################### - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.num_features = embed_dim - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # merge non-overlapping patches into image - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build Residual Swin Transformer blocks (RSTB) - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, 
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers.append(layer) - - if self.upsampler == 'pixelshuffle_hf': - self.layers_hf = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers_hf.append(layer) - - self.norm = norm_layer(self.num_features) - - # build the last conv layer in deep feature extraction - if resi_connection == '1conv': - self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) - - ##################################################################################################### - ################################ 3, high quality image reconstruction ################################ - if self.upsampler == 'pixelshuffle': - # for classical SR - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - elif self.upsampler == 'pixelshuffle_aux': - self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) - self.conv_before_upsample = nn.Sequential( - nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.conv_after_aux = nn.Sequential( - nn.Conv2d(3, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - - elif self.upsampler == 'pixelshuffle_hf': - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.upsample_hf = Upsample_hf(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - self.conv_before_upsample_hf = nn.Sequential( - nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR (to save parameters) - self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, - (patches_resolution[0], patches_resolution[1])) - elif self.upsampler == 'nearest+conv': - # for real-world SR (less artifacts) - assert self.upscale == 4, 'only support x4 now.' 
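The pixelshuffle reconstruction branches above rely on the Upsample module defined earlier: a 3x3 convolution expands the channel count by scale**2 and nn.PixelShuffle folds those channels back into spatial resolution. A minimal sketch with assumed toy sizes (num_feat=64, scale=2), shown here only for illustration:

import torch
import torch.nn as nn

num_feat, scale = 64, 2                    # assumed toy sizes
up = nn.Sequential(
    nn.Conv2d(num_feat, (scale ** 2) * num_feat, 3, 1, 1),  # 64 -> 256 channels
    nn.PixelShuffle(scale))                                  # 256 channels -> 64 channels at 2x resolution
y = up(torch.randn(1, num_feat, 16, 16))
print(y.shape)  # torch.Size([1, 64, 32, 32])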
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - else: - # for image denoising and JPEG compression artifact reduction - self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def check_image_size(self, x): - _, _, h, w = x.size() - mod_pad_h = (self.window_size - h % self.window_size) % self.window_size - mod_pad_w = (self.window_size - w % self.window_size) % self.window_size - x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') - return x - - def forward_features(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward_features_hf(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers_hf: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward(self, x): - H, W = x.shape[2:] - x = self.check_image_size(x) - - self.mean = self.mean.type_as(x) - x = (x - self.mean) * self.img_range - - if self.upsampler == 'pixelshuffle': - # for classical SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.conv_last(self.upsample(x)) - elif self.upsampler == 'pixelshuffle_aux': - bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False) - bicubic = self.conv_bicubic(bicubic) - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - aux = self.conv_aux(x) # b, 3, LR_H, LR_W - x = self.conv_after_aux(aux) - x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale] - x = self.conv_last(x) - aux = aux / self.img_range + self.mean - elif self.upsampler == 'pixelshuffle_hf': - # for classical SR with HF - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x_before = self.conv_before_upsample(x) - x_out = self.conv_last(self.upsample(x_before)) - - x_hf = self.conv_first_hf(x_before) - x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf - x_hf = self.conv_before_upsample_hf(x_hf) - x_hf = self.conv_last_hf(self.upsample_hf(x_hf)) - x = x_out + x_hf - x_hf = x_hf / self.img_range + self.mean - - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = 
self.upsample(x) - elif self.upsampler == 'nearest+conv': - # for real-world SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.conv_last(self.lrelu(self.conv_hr(x))) - else: - # for image denoising and JPEG compression artifact reduction - x_first = self.conv_first(x) - res = self.conv_after_body(self.forward_features(x_first)) + x_first - x = x + self.conv_last(res) - - x = x / self.img_range + self.mean - if self.upsampler == "pixelshuffle_aux": - return x[:, :, :H*self.upscale, :W*self.upscale], aux - - elif self.upsampler == "pixelshuffle_hf": - x_out = x_out / self.img_range + self.mean - return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale] - - else: - return x[:, :, :H*self.upscale, :W*self.upscale] - - def flops(self): - flops = 0 - H, W = self.patches_resolution - flops += H * W * 3 * self.embed_dim * 9 - flops += self.patch_embed.flops() - for layer in self.layers: - flops += layer.flops() - flops += H * W * 3 * self.embed_dim * self.embed_dim - flops += self.upsample.flops() - return flops - - -if __name__ == '__main__': - upscale = 4 - window_size = 8 - height = (1024 // upscale // window_size + 1) * window_size - width = (720 // upscale // window_size + 1) * window_size - model = Swin2SR(upscale=2, img_size=(height, width), - window_size=window_size, img_range=1., depths=[6, 6, 6, 6], - embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') - print(model) - print(height, width, model.flops() / 1e9) - - x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) diff --git a/modules/codeformer/codeformer_arch.py b/modules/codeformer/codeformer_arch.py deleted file mode 100644 index 12db6814..00000000 --- a/modules/codeformer/codeformer_arch.py +++ /dev/null @@ -1,276 +0,0 @@ -# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py - -import math -import torch -from torch import nn, Tensor -import torch.nn.functional as F -from typing import Optional - -from modules.codeformer.vqgan_arch import VQAutoEncoder, ResBlock -from basicsr.utils.registry import ARCH_REGISTRY - -def calc_mean_std(feat, eps=1e-5): - """Calculate mean and std for adaptive_instance_normalization. - - Args: - feat (Tensor): 4D tensor. - eps (float): A small value added to the variance to avoid - divide-by-zero. Default: 1e-5. - """ - size = feat.size() - assert len(size) == 4, 'The input feature should be 4D tensor.' - b, c = size[:2] - feat_var = feat.view(b, c, -1).var(dim=2) + eps - feat_std = feat_var.sqrt().view(b, c, 1, 1) - feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1) - return feat_mean, feat_std - - -def adaptive_instance_normalization(content_feat, style_feat): - """Adaptive instance normalization. - - Adjust the reference features to have the similar color and illuminations - as those in the degradate features. - - Args: - content_feat (Tensor): The reference feature. - style_feat (Tensor): The degradate features. 
- """ - size = content_feat.size() - style_mean, style_std = calc_mean_std(style_feat) - content_mean, content_std = calc_mean_std(content_feat) - normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) - return normalized_feat * style_std.expand(size) + style_mean.expand(size) - - -class PositionEmbeddingSine(nn.Module): - """ - This is a more standard version of the position embedding, very similar to the one - used by the Attention is all you need paper, generalized to work on images. - """ - - def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): - super().__init__() - self.num_pos_feats = num_pos_feats - self.temperature = temperature - self.normalize = normalize - if scale is not None and normalize is False: - raise ValueError("normalize should be True if scale is passed") - if scale is None: - scale = 2 * math.pi - self.scale = scale - - def forward(self, x, mask=None): - if mask is None: - mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) - not_mask = ~mask - y_embed = not_mask.cumsum(1, dtype=torch.float32) - x_embed = not_mask.cumsum(2, dtype=torch.float32) - if self.normalize: - eps = 1e-6 - y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale - x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) - dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) - - pos_x = x_embed[:, :, :, None] / dim_t - pos_y = y_embed[:, :, :, None] / dim_t - pos_x = torch.stack( - (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 - ).flatten(3) - pos_y = torch.stack( - (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 - ).flatten(3) - pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) - return pos - -def _get_activation_fn(activation): - """Return an activation function given a string""" - if activation == "relu": - return F.relu - if activation == "gelu": - return F.gelu - if activation == "glu": - return F.glu - raise RuntimeError(F"activation should be relu/gelu, not {activation}.") - - -class TransformerSALayer(nn.Module): - def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"): - super().__init__() - self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout) - # Implementation of Feedforward model - MLP - self.linear1 = nn.Linear(embed_dim, dim_mlp) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_mlp, embed_dim) - - self.norm1 = nn.LayerNorm(embed_dim) - self.norm2 = nn.LayerNorm(embed_dim) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - def with_pos_embed(self, tensor, pos: Optional[Tensor]): - return tensor if pos is None else tensor + pos - - def forward(self, tgt, - tgt_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - query_pos: Optional[Tensor] = None): - - # self attention - tgt2 = self.norm1(tgt) - q = k = self.with_pos_embed(tgt2, query_pos) - tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, - key_padding_mask=tgt_key_padding_mask)[0] - tgt = tgt + self.dropout1(tgt2) - - # ffn - tgt2 = self.norm2(tgt) - tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) - tgt = tgt + self.dropout2(tgt2) - return tgt - -class Fuse_sft_block(nn.Module): - def __init__(self, in_ch, out_ch): - super().__init__() - self.encode_enc = ResBlock(2*in_ch, 
out_ch) - - self.scale = nn.Sequential( - nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), - nn.LeakyReLU(0.2, True), - nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) - - self.shift = nn.Sequential( - nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), - nn.LeakyReLU(0.2, True), - nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) - - def forward(self, enc_feat, dec_feat, w=1): - enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1)) - scale = self.scale(enc_feat) - shift = self.shift(enc_feat) - residual = w * (dec_feat * scale + shift) - out = dec_feat + residual - return out - - -@ARCH_REGISTRY.register() -class CodeFormer(VQAutoEncoder): - def __init__(self, dim_embd=512, n_head=8, n_layers=9, - codebook_size=1024, latent_size=256, - connect_list=('32', '64', '128', '256'), - fix_modules=('quantize', 'generator')): - super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) - - if fix_modules is not None: - for module in fix_modules: - for param in getattr(self, module).parameters(): - param.requires_grad = False - - self.connect_list = connect_list - self.n_layers = n_layers - self.dim_embd = dim_embd - self.dim_mlp = dim_embd*2 - - self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) - self.feat_emb = nn.Linear(256, self.dim_embd) - - # transformer - self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) - for _ in range(self.n_layers)]) - - # logits_predict head - self.idx_pred_layer = nn.Sequential( - nn.LayerNorm(dim_embd), - nn.Linear(dim_embd, codebook_size, bias=False)) - - self.channels = { - '16': 512, - '32': 256, - '64': 256, - '128': 128, - '256': 128, - '512': 64, - } - - # after second residual block for > 16, before attn layer for ==16 - self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18} - # after first residual block for > 16, before attn layer for ==16 - self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21} - - # fuse_convs_dict - self.fuse_convs_dict = nn.ModuleDict() - for f_size in self.connect_list: - in_ch = self.channels[f_size] - self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch) - - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Embedding)): - module.weight.data.normal_(mean=0.0, std=0.02) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - - def forward(self, x, w=0, detach_16=True, code_only=False, adain=False): - # ################### Encoder ##################### - enc_feat_dict = {} - out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] - for i, block in enumerate(self.encoder.blocks): - x = block(x) - if i in out_list: - enc_feat_dict[str(x.shape[-1])] = x.clone() - - lq_feat = x - # ################# Transformer ################### - # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat) - pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1) - # BCHW -> BC(HW) -> (HW)BC - feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1)) - query_emb = feat_emb - # Transformer encoder - for layer in self.ft_layers: - query_emb = layer(query_emb, query_pos=pos_emb) - - # output logits - logits = self.idx_pred_layer(query_emb) # (hw)bn - logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n - - if code_only: # for training stage II - # logits 
doesn't need softmax before cross_entropy loss - return logits, lq_feat - - # ################# Quantization ################### - # if self.training: - # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight]) - # # b(hw)c -> bc(hw) -> bchw - # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape) - # ------------ - soft_one_hot = F.softmax(logits, dim=2) - _, top_idx = torch.topk(soft_one_hot, 1, dim=2) - quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256]) - # preserve gradients - # quant_feat = lq_feat + (quant_feat - lq_feat).detach() - - if detach_16: - quant_feat = quant_feat.detach() # for training stage III - if adain: - quant_feat = adaptive_instance_normalization(quant_feat, lq_feat) - - # ################## Generator #################### - x = quant_feat - fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] - - for i, block in enumerate(self.generator.blocks): - x = block(x) - if i in fuse_list: # fuse after i-th block - f_size = str(x.shape[-1]) - if w>0: - x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) - out = x - # logits doesn't need softmax before cross_entropy loss - return out, logits, lq_feat diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py deleted file mode 100644 index 09ee6660..00000000 --- a/modules/codeformer/vqgan_arch.py +++ /dev/null @@ -1,435 +0,0 @@ -# this file is copied from CodeFormer repository. Please see comment in modules/codeformer_model.py - -''' -VQGAN code, adapted from the original created by the Unleashing Transformers authors: -https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py - -''' -import torch -import torch.nn as nn -import torch.nn.functional as F -from basicsr.utils import get_root_logger -from basicsr.utils.registry import ARCH_REGISTRY - -def normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -@torch.jit.script -def swish(x): - return x*torch.sigmoid(x) - - -# Define VQVAE classes -class VectorQuantizer(nn.Module): - def __init__(self, codebook_size, emb_dim, beta): - super(VectorQuantizer, self).__init__() - self.codebook_size = codebook_size # number of embeddings - self.emb_dim = emb_dim # dimension of embedding - self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 - self.embedding = nn.Embedding(self.codebook_size, self.emb_dim) - self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) - - def forward(self, z): - # reshape z -> (batch, height, width, channel) and flatten - z = z.permute(0, 2, 3, 1).contiguous() - z_flattened = z.view(-1, self.emb_dim) - - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \ - 2 * torch.matmul(z_flattened, self.embedding.weight.t()) - - mean_distance = torch.mean(d) - # find closest encodings - # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) - min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False) - # [0-1], higher score, higher confidence - min_encoding_scores = torch.exp(-min_encoding_scores/10) - - min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z) - min_encodings.scatter_(1, min_encoding_indices, 1) - - # get quantized latent vectors - z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) - # 
compute loss for embedding - loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2) - # preserve gradients - z_q = z + (z_q - z).detach() - - # perplexity - e_mean = torch.mean(min_encodings, dim=0) - perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q, loss, { - "perplexity": perplexity, - "min_encodings": min_encodings, - "min_encoding_indices": min_encoding_indices, - "min_encoding_scores": min_encoding_scores, - "mean_distance": mean_distance - } - - def get_codebook_feat(self, indices, shape): - # input indices: batch*token_num -> (batch*token_num)*1 - # shape: batch, height, width, channel - indices = indices.view(-1,1) - min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices) - min_encodings.scatter_(1, indices, 1) - # get quantized latent vectors - z_q = torch.matmul(min_encodings.float(), self.embedding.weight) - - if shape is not None: # reshape back to match original input shape - z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous() - - return z_q - - -class GumbelQuantizer(nn.Module): - def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0): - super().__init__() - self.codebook_size = codebook_size # number of embeddings - self.emb_dim = emb_dim # dimension of embedding - self.straight_through = straight_through - self.temperature = temp_init - self.kl_weight = kl_weight - self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits - self.embed = nn.Embedding(codebook_size, emb_dim) - - def forward(self, z): - hard = self.straight_through if self.training else True - - logits = self.proj(z) - - soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard) - - z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) - - # + kl divergence to the prior loss - qy = F.softmax(logits, dim=1) - diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean() - min_encoding_indices = soft_one_hot.argmax(dim=1) - - return z_q, diff, { - "min_encoding_indices": min_encoding_indices - } - - -class Downsample(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) - - def forward(self, x): - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - return x - - -class Upsample(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) - - def forward(self, x): - x = F.interpolate(x, scale_factor=2.0, mode="nearest") - x = self.conv(x) - - return x - - -class ResBlock(nn.Module): - def __init__(self, in_channels, out_channels=None): - super(ResBlock, self).__init__() - self.in_channels = in_channels - self.out_channels = in_channels if out_channels is None else out_channels - self.norm1 = normalize(in_channels) - self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - self.norm2 = normalize(out_channels) - self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - if self.in_channels != self.out_channels: - self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, x_in): - x = x_in - x = 
self.norm1(x) - x = swish(x) - x = self.conv1(x) - x = self.norm2(x) - x = swish(x) - x = self.conv2(x) - if self.in_channels != self.out_channels: - x_in = self.conv_out(x_in) - - return x + x_in - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0 - ) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = q.reshape(b, c, h*w) - q = q.permute(0, 2, 1) - k = k.reshape(b, c, h*w) - w_ = torch.bmm(q, k) - w_ = w_ * (int(c)**(-0.5)) - w_ = F.softmax(w_, dim=2) - - # attend to values - v = v.reshape(b, c, h*w) - w_ = w_.permute(0, 2, 1) - h_ = torch.bmm(v, w_) - h_ = h_.reshape(b, c, h, w) - - h_ = self.proj_out(h_) - - return x+h_ - - -class Encoder(nn.Module): - def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions): - super().__init__() - self.nf = nf - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.attn_resolutions = attn_resolutions - - curr_res = self.resolution - in_ch_mult = (1,)+tuple(ch_mult) - - blocks = [] - # initial convultion - blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)) - - # residual and downsampling blocks, with attention on smaller res (16x16) - for i in range(self.num_resolutions): - block_in_ch = nf * in_ch_mult[i] - block_out_ch = nf * ch_mult[i] - for _ in range(self.num_res_blocks): - blocks.append(ResBlock(block_in_ch, block_out_ch)) - block_in_ch = block_out_ch - if curr_res in attn_resolutions: - blocks.append(AttnBlock(block_in_ch)) - - if i != self.num_resolutions - 1: - blocks.append(Downsample(block_in_ch)) - curr_res = curr_res // 2 - - # non-local attention block - blocks.append(ResBlock(block_in_ch, block_in_ch)) - blocks.append(AttnBlock(block_in_ch)) - blocks.append(ResBlock(block_in_ch, block_in_ch)) - - # normalise and convert to latent size - blocks.append(normalize(block_in_ch)) - blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)) - self.blocks = nn.ModuleList(blocks) - - def forward(self, x): - for block in self.blocks: - x = block(x) - - return x - - -class Generator(nn.Module): - def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): - super().__init__() - self.nf = nf - self.ch_mult = ch_mult - self.num_resolutions = len(self.ch_mult) - self.num_res_blocks = res_blocks - self.resolution = img_size - self.attn_resolutions = attn_resolutions - self.in_channels = emb_dim - self.out_channels = 3 - block_in_ch = self.nf * self.ch_mult[-1] - curr_res = self.resolution // 2 ** (self.num_resolutions-1) - - blocks = [] - # initial conv - blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)) - - # non-local attention block - blocks.append(ResBlock(block_in_ch, block_in_ch)) - blocks.append(AttnBlock(block_in_ch)) - blocks.append(ResBlock(block_in_ch, block_in_ch)) - - for i in 
reversed(range(self.num_resolutions)): - block_out_ch = self.nf * self.ch_mult[i] - - for _ in range(self.num_res_blocks): - blocks.append(ResBlock(block_in_ch, block_out_ch)) - block_in_ch = block_out_ch - - if curr_res in self.attn_resolutions: - blocks.append(AttnBlock(block_in_ch)) - - if i != 0: - blocks.append(Upsample(block_in_ch)) - curr_res = curr_res * 2 - - blocks.append(normalize(block_in_ch)) - blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) - - self.blocks = nn.ModuleList(blocks) - - - def forward(self, x): - for block in self.blocks: - x = block(x) - - return x - - -@ARCH_REGISTRY.register() -class VQAutoEncoder(nn.Module): - def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=None, codebook_size=1024, emb_dim=256, - beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): - super().__init__() - logger = get_root_logger() - self.in_channels = 3 - self.nf = nf - self.n_blocks = res_blocks - self.codebook_size = codebook_size - self.embed_dim = emb_dim - self.ch_mult = ch_mult - self.resolution = img_size - self.attn_resolutions = attn_resolutions or [16] - self.quantizer_type = quantizer - self.encoder = Encoder( - self.in_channels, - self.nf, - self.embed_dim, - self.ch_mult, - self.n_blocks, - self.resolution, - self.attn_resolutions - ) - if self.quantizer_type == "nearest": - self.beta = beta #0.25 - self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta) - elif self.quantizer_type == "gumbel": - self.gumbel_num_hiddens = emb_dim - self.straight_through = gumbel_straight_through - self.kl_weight = gumbel_kl_weight - self.quantize = GumbelQuantizer( - self.codebook_size, - self.embed_dim, - self.gumbel_num_hiddens, - self.straight_through, - self.kl_weight - ) - self.generator = Generator( - self.nf, - self.embed_dim, - self.ch_mult, - self.n_blocks, - self.resolution, - self.attn_resolutions - ) - - if model_path is not None: - chkpt = torch.load(model_path, map_location='cpu') - if 'params_ema' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema']) - logger.info(f'vqgan is loaded from: {model_path} [params_ema]') - elif 'params' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) - logger.info(f'vqgan is loaded from: {model_path} [params]') - else: - raise ValueError('Wrong params!') - - - def forward(self, x): - x = self.encoder(x) - quant, codebook_loss, quant_stats = self.quantize(x) - x = self.generator(quant) - return x, codebook_loss, quant_stats - - - -# patch based discriminator -@ARCH_REGISTRY.register() -class VQGANDiscriminator(nn.Module): - def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None): - super().__init__() - - layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)] - ndf_mult = 1 - ndf_mult_prev = 1 - for n in range(1, n_layers): # gradually increase the number of filters - ndf_mult_prev = ndf_mult - ndf_mult = min(2 ** n, 8) - layers += [ - nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False), - nn.BatchNorm2d(ndf * ndf_mult), - nn.LeakyReLU(0.2, True) - ] - - ndf_mult_prev = ndf_mult - ndf_mult = min(2 ** n_layers, 8) - - layers += [ - nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False), - nn.BatchNorm2d(ndf * ndf_mult), - nn.LeakyReLU(0.2, True) - ] - - layers += [ - nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, 
stride=1, padding=1)] # output 1 channel prediction map - self.main = nn.Sequential(*layers) - - if model_path is not None: - chkpt = torch.load(model_path, map_location='cpu') - if 'params_d' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d']) - elif 'params' in chkpt: - self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) - else: - raise ValueError('Wrong params!') - - def forward(self, x): - return self.main(x) diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index da42b5e9..517eadfd 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -8,9 +8,6 @@ import modules.shared from modules import shared, devices, modelloader, errors from modules.paths import models_path -# codeformer people made a choice to include modified basicsr library to their project which makes -# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN. -# I am making a choice to include some files from codeformer to work around this issue. model_dir = "Codeformer" model_path = os.path.join(models_path, model_dir) model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' @@ -18,115 +15,127 @@ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codef codeformer = None -def setup_model(dirname): - os.makedirs(model_path, exist_ok=True) - - path = modules.paths.paths.get("CodeFormer", None) - if path is None: - return - - try: +class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration): + def name(self): + return "CodeFormer" + + def __init__(self, dirname): + self.net = None + self.face_helper = None + self.cmd_dir = dirname + + def create_models(self): + from facexlib.detection import retinaface + from facexlib.utils.face_restoration_helper import FaceRestoreHelper + + if self.net is not None and self.face_helper is not None: + self.net.to(devices.device_codeformer) + return self.net, self.face_helper + model_paths = modelloader.load_models( + model_path, + model_url, + self.cmd_dir, + download_name='codeformer-v0.1.0.pth', + ext_filter=['.pth'], + ) + + if len(model_paths) != 0: + ckpt_path = model_paths[0] + else: + print("Unable to load codeformer model.") + return None, None + net = modelloader.load_spandrel_model(ckpt_path, device=devices.device_codeformer) + + if hasattr(retinaface, 'device'): + retinaface.device = devices.device_codeformer + + face_helper = FaceRestoreHelper( + upscale_factor=1, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + use_parse=True, + device=devices.device_codeformer, + ) + + self.net = net + self.face_helper = face_helper + + def send_model_to(self, device): + self.net.to(device) + self.face_helper.face_det.to(device) + self.face_helper.face_parse.to(device) + + def restore(self, np_image, w=None): from torchvision.transforms.functional import normalize - from modules.codeformer.codeformer_arch import CodeFormer from basicsr.utils import img2tensor, tensor2img - from facelib.utils.face_restoration_helper import FaceRestoreHelper - from facelib.detection.retinaface import retinaface - - net_class = CodeFormer - - class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration): - def name(self): - return "CodeFormer" - - def __init__(self, dirname): - self.net = None - self.face_helper = None - self.cmd_dir = dirname - - def create_models(self): - - if self.net is not None and self.face_helper is not None: - 
self.net.to(devices.device_codeformer) - return self.net, self.face_helper - model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth', ext_filter=['.pth']) - if len(model_paths) != 0: - ckpt_path = model_paths[0] - else: - print("Unable to load codeformer model.") - return None, None - net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer) - checkpoint = torch.load(ckpt_path)['params_ema'] - net.load_state_dict(checkpoint) - net.eval() + np_image = np_image[:, :, ::-1] - if hasattr(retinaface, 'device'): - retinaface.device = devices.device_codeformer - face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=devices.device_codeformer) + original_resolution = np_image.shape[0:2] - self.net = net - self.face_helper = face_helper + self.create_models() + if self.net is None or self.face_helper is None: + return np_image - return net, face_helper + self.send_model_to(devices.device_codeformer) - def send_model_to(self, device): - self.net.to(device) - self.face_helper.face_det.to(device) - self.face_helper.face_parse.to(device) + self.face_helper.clean_all() + self.face_helper.read_image(np_image) + self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) + self.face_helper.align_warp_face() - def restore(self, np_image, w=None): - np_image = np_image[:, :, ::-1] + for cropped_face in self.face_helper.cropped_faces: + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) - original_resolution = np_image.shape[0:2] + try: + with torch.no_grad(): + res = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True) + if isinstance(res, tuple): + output = res[0] + else: + output = res + if not isinstance(res, torch.Tensor): + raise TypeError(f"Expected torch.Tensor, got {type(res)}") + restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) + del output + devices.torch_gc() + except Exception: + errors.report('Failed inference for CodeFormer', exc_info=True) + restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) - self.create_models() - if self.net is None or self.face_helper is None: - return np_image + restored_face = restored_face.astype('uint8') + self.face_helper.add_restored_face(restored_face) - self.send_model_to(devices.device_codeformer) + self.face_helper.get_inverse_affine(None) - self.face_helper.clean_all() - self.face_helper.read_image(np_image) - self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) - self.face_helper.align_warp_face() + restored_img = self.face_helper.paste_faces_to_input_image() + restored_img = restored_img[:, :, ::-1] - for cropped_face in self.face_helper.cropped_faces: - cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) - normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) + if original_resolution != restored_img.shape[0:2]: + restored_img = cv2.resize( + restored_img, + (0, 0), + fx=original_resolution[1]/restored_img.shape[1], + fy=original_resolution[0]/restored_img.shape[0], + 
interpolation=cv2.INTER_LINEAR, + ) - try: - with torch.no_grad(): - output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0] - restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) - del output - devices.torch_gc() - except Exception: - errors.report('Failed inference for CodeFormer', exc_info=True) - restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) + self.face_helper.clean_all() - restored_face = restored_face.astype('uint8') - self.face_helper.add_restored_face(restored_face) + if shared.opts.face_restoration_unload: + self.send_model_to(devices.cpu) - self.face_helper.get_inverse_affine(None) + return restored_img - restored_img = self.face_helper.paste_faces_to_input_image() - restored_img = restored_img[:, :, ::-1] - - if original_resolution != restored_img.shape[0:2]: - restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR) - - self.face_helper.clean_all() - - if shared.opts.face_restoration_unload: - self.send_model_to(devices.cpu) - - return restored_img +def setup_model(dirname): + os.makedirs(model_path, exist_ok=True) + try: global codeformer codeformer = FaceRestorerCodeFormer(dirname) shared.face_restorers.append(codeformer) - except Exception: errors.report("Error setting up CodeFormer", exc_info=True) - - # sys.path = stored_sys_path diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index c0d22a99..a7c7c9e3 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,122 +1,9 @@ -import sys - -import torch - -import modules.esrgan_model_arch as arch -from modules import modelloader, devices +from modules import modelloader, devices, errors from modules.shared import opts from modules.upscaler import Upscaler, UpscalerData from modules.upscaler_utils import upscale_with_model -def mod2normal(state_dict): - # this code is copied from https://github.com/victorca25/iNNfer - if 'conv_first.weight' in state_dict: - crt_net = {} - items = list(state_dict) - - crt_net['model.0.weight'] = state_dict['conv_first.weight'] - crt_net['model.0.bias'] = state_dict['conv_first.bias'] - - for k in items.copy(): - if 'RDB' in k: - ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') - if '.weight' in k: - ori_k = ori_k.replace('.weight', '.0.weight') - elif '.bias' in k: - ori_k = ori_k.replace('.bias', '.0.bias') - crt_net[ori_k] = state_dict[k] - items.remove(k) - - crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight'] - crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias'] - crt_net['model.3.weight'] = state_dict['upconv1.weight'] - crt_net['model.3.bias'] = state_dict['upconv1.bias'] - crt_net['model.6.weight'] = state_dict['upconv2.weight'] - crt_net['model.6.bias'] = state_dict['upconv2.bias'] - crt_net['model.8.weight'] = state_dict['HRconv.weight'] - crt_net['model.8.bias'] = state_dict['HRconv.bias'] - crt_net['model.10.weight'] = state_dict['conv_last.weight'] - crt_net['model.10.bias'] = state_dict['conv_last.bias'] - state_dict = crt_net - return state_dict - - -def resrgan2normal(state_dict, nb=23): - # this code is copied from https://github.com/victorca25/iNNfer - if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict: - re8x = 0 - crt_net = {} - items = list(state_dict) - - crt_net['model.0.weight'] = state_dict['conv_first.weight'] - crt_net['model.0.bias'] = state_dict['conv_first.bias'] - - for k 
in items.copy(): - if "rdb" in k: - ori_k = k.replace('body.', 'model.1.sub.') - ori_k = ori_k.replace('.rdb', '.RDB') - if '.weight' in k: - ori_k = ori_k.replace('.weight', '.0.weight') - elif '.bias' in k: - ori_k = ori_k.replace('.bias', '.0.bias') - crt_net[ori_k] = state_dict[k] - items.remove(k) - - crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight'] - crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias'] - crt_net['model.3.weight'] = state_dict['conv_up1.weight'] - crt_net['model.3.bias'] = state_dict['conv_up1.bias'] - crt_net['model.6.weight'] = state_dict['conv_up2.weight'] - crt_net['model.6.bias'] = state_dict['conv_up2.bias'] - - if 'conv_up3.weight' in state_dict: - # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py - re8x = 3 - crt_net['model.9.weight'] = state_dict['conv_up3.weight'] - crt_net['model.9.bias'] = state_dict['conv_up3.bias'] - - crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight'] - crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias'] - crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight'] - crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias'] - - state_dict = crt_net - return state_dict - - -def infer_params(state_dict): - # this code is copied from https://github.com/victorca25/iNNfer - scale2x = 0 - scalemin = 6 - n_uplayer = 0 - plus = False - - for block in list(state_dict): - parts = block.split(".") - n_parts = len(parts) - if n_parts == 5 and parts[2] == "sub": - nb = int(parts[3]) - elif n_parts == 3: - part_num = int(parts[1]) - if (part_num > scalemin - and parts[0] == "model" - and parts[2] == "weight"): - scale2x += 1 - if part_num > n_uplayer: - n_uplayer = part_num - out_nc = state_dict[block].shape[0] - if not plus and "conv1x1" in block: - plus = True - - nf = state_dict["model.0.weight"].shape[0] - in_nc = state_dict["model.0.weight"].shape[1] - out_nc = out_nc - scale = 2 ** scale2x - - return in_nc, out_nc, nf, nb, plus, scale - - class UpscalerESRGAN(Upscaler): def __init__(self, dirname): self.name = "ESRGAN" @@ -142,12 +29,11 @@ class UpscalerESRGAN(Upscaler): def do_upscale(self, img, selected_model): try: model = self.load_model(selected_model) - except Exception as e: - print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr) + except Exception: + errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True) return img model.to(devices.device_esrgan) - img = esrgan_upscale(model, img) - return img + return esrgan_upscale(model, img) def load_model(self, path: str): if path.startswith("http"): @@ -160,33 +46,10 @@ class UpscalerESRGAN(Upscaler): else: filename = path - state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) - - if "params_ema" in state_dict: - state_dict = state_dict["params_ema"] - elif "params" in state_dict: - state_dict = state_dict["params"] - num_conv = 16 if "realesr-animevideov3" in filename else 32 - model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu') - model.load_state_dict(state_dict) - model.eval() - return model - - if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict: - nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23 - state_dict = resrgan2normal(state_dict, nb) - elif "conv_first.weight" in state_dict: - state_dict = mod2normal(state_dict) - elif "model.0.weight" not in state_dict: - raise 
Exception("The file is not a recognized ESRGAN model.") - - in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict) - - model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus) - model.load_state_dict(state_dict) - model.eval() - - return model + return modelloader.load_spandrel_model( + filename, + device=('cpu' if devices.device_esrgan.type == 'mps' else None), + ) def esrgan_upscale(model, img): diff --git a/modules/esrgan_model_arch.py b/modules/esrgan_model_arch.py deleted file mode 100644 index 2b9888ba..00000000 --- a/modules/esrgan_model_arch.py +++ /dev/null @@ -1,465 +0,0 @@ -# this file is adapted from https://github.com/victorca25/iNNfer - -from collections import OrderedDict -import math -import torch -import torch.nn as nn -import torch.nn.functional as F - - -#################### -# RRDBNet Generator -#################### - -class RRDBNet(nn.Module): - def __init__(self, in_nc, out_nc, nf, nb, nr=3, gc=32, upscale=4, norm_type=None, - act_type='leakyrelu', mode='CNA', upsample_mode='upconv', convtype='Conv2D', - finalact=None, gaussian_noise=False, plus=False): - super(RRDBNet, self).__init__() - n_upscale = int(math.log(upscale, 2)) - if upscale == 3: - n_upscale = 1 - - self.resrgan_scale = 0 - if in_nc % 16 == 0: - self.resrgan_scale = 1 - elif in_nc != 4 and in_nc % 4 == 0: - self.resrgan_scale = 2 - - fea_conv = conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None, convtype=convtype) - rb_blocks = [RRDB(nf, nr, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', - norm_type=norm_type, act_type=act_type, mode='CNA', convtype=convtype, - gaussian_noise=gaussian_noise, plus=plus) for _ in range(nb)] - LR_conv = conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode, convtype=convtype) - - if upsample_mode == 'upconv': - upsample_block = upconv_block - elif upsample_mode == 'pixelshuffle': - upsample_block = pixelshuffle_block - else: - raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found') - if upscale == 3: - upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype) - else: - upsampler = [upsample_block(nf, nf, act_type=act_type, convtype=convtype) for _ in range(n_upscale)] - HR_conv0 = conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type, convtype=convtype) - HR_conv1 = conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None, convtype=convtype) - - outact = act(finalact) if finalact else None - - self.model = sequential(fea_conv, ShortcutBlock(sequential(*rb_blocks, LR_conv)), - *upsampler, HR_conv0, HR_conv1, outact) - - def forward(self, x, outm=None): - if self.resrgan_scale == 1: - feat = pixel_unshuffle(x, scale=4) - elif self.resrgan_scale == 2: - feat = pixel_unshuffle(x, scale=2) - else: - feat = x - - return self.model(feat) - - -class RRDB(nn.Module): - """ - Residual in Residual Dense Block - (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks) - """ - - def __init__(self, nf, nr=3, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', - norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', - spectral_norm=False, gaussian_noise=False, plus=False): - super(RRDB, self).__init__() - # This is for backwards compatibility with existing models - if nr == 3: - self.RDB1 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) - self.RDB2 = ResidualDenseBlock_5C(nf, 
kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) - self.RDB3 = ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) - else: - RDB_list = [ResidualDenseBlock_5C(nf, kernel_size, gc, stride, bias, pad_type, - norm_type, act_type, mode, convtype, spectral_norm=spectral_norm, - gaussian_noise=gaussian_noise, plus=plus) for _ in range(nr)] - self.RDBs = nn.Sequential(*RDB_list) - - def forward(self, x): - if hasattr(self, 'RDB1'): - out = self.RDB1(x) - out = self.RDB2(out) - out = self.RDB3(out) - else: - out = self.RDBs(x) - return out * 0.2 + x - - -class ResidualDenseBlock_5C(nn.Module): - """ - Residual Dense Block - The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18) - Modified options that can be used: - - "Partial Convolution based Padding" arXiv:1811.11718 - - "Spectral normalization" arXiv:1802.05957 - - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C. - {Rakotonirina} and A. {Rasoanaivo} - """ - - def __init__(self, nf=64, kernel_size=3, gc=32, stride=1, bias=1, pad_type='zero', - norm_type=None, act_type='leakyrelu', mode='CNA', convtype='Conv2D', - spectral_norm=False, gaussian_noise=False, plus=False): - super(ResidualDenseBlock_5C, self).__init__() - - self.noise = GaussianNoise() if gaussian_noise else None - self.conv1x1 = conv1x1(nf, gc) if plus else None - - self.conv1 = conv_block(nf, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - self.conv2 = conv_block(nf+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - self.conv3 = conv_block(nf+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - self.conv4 = conv_block(nf+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=act_type, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - if mode == 'CNA': - last_act = None - else: - last_act = act_type - self.conv5 = conv_block(nf+4*gc, nf, 3, stride, bias=bias, pad_type=pad_type, - norm_type=norm_type, act_type=last_act, mode=mode, convtype=convtype, - spectral_norm=spectral_norm) - - def forward(self, x): - x1 = self.conv1(x) - x2 = self.conv2(torch.cat((x, x1), 1)) - if self.conv1x1: - x2 = x2 + self.conv1x1(x) - x3 = self.conv3(torch.cat((x, x1, x2), 1)) - x4 = self.conv4(torch.cat((x, x1, x2, x3), 1)) - if self.conv1x1: - x4 = x4 + x2 - x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) - if self.noise: - return self.noise(x5.mul(0.2) + x) - else: - return x5 * 0.2 + x - - -#################### -# ESRGANplus -#################### - -class GaussianNoise(nn.Module): - def __init__(self, sigma=0.1, is_relative_detach=False): - super().__init__() - self.sigma = sigma - self.is_relative_detach = is_relative_detach - self.noise = torch.tensor(0, dtype=torch.float) - - def forward(self, x): - if self.training and self.sigma != 0: - self.noise = self.noise.to(x.device) - scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x - sampled_noise = self.noise.repeat(*x.size()).normal_() * scale - x = x + sampled_noise - 
return x - -def conv1x1(in_planes, out_planes, stride=1): - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) - - -#################### -# SRVGGNetCompact -#################### - -class SRVGGNetCompact(nn.Module): - """A compact VGG-style network structure for super-resolution. - This class is copied from https://github.com/xinntao/Real-ESRGAN - """ - - def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): - super(SRVGGNetCompact, self).__init__() - self.num_in_ch = num_in_ch - self.num_out_ch = num_out_ch - self.num_feat = num_feat - self.num_conv = num_conv - self.upscale = upscale - self.act_type = act_type - - self.body = nn.ModuleList() - # the first conv - self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) - # the first activation - if act_type == 'relu': - activation = nn.ReLU(inplace=True) - elif act_type == 'prelu': - activation = nn.PReLU(num_parameters=num_feat) - elif act_type == 'leakyrelu': - activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) - self.body.append(activation) - - # the body structure - for _ in range(num_conv): - self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) - # activation - if act_type == 'relu': - activation = nn.ReLU(inplace=True) - elif act_type == 'prelu': - activation = nn.PReLU(num_parameters=num_feat) - elif act_type == 'leakyrelu': - activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) - self.body.append(activation) - - # the last conv - self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) - # upsample - self.upsampler = nn.PixelShuffle(upscale) - - def forward(self, x): - out = x - for i in range(0, len(self.body)): - out = self.body[i](out) - - out = self.upsampler(out) - # add the nearest upsampled image, so that the network learns the residual - base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') - out += base - return out - - -#################### -# Upsampler -#################### - -class Upsample(nn.Module): - r"""Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data. - The input data is assumed to be of the form - `minibatch x channels x [optional depth] x [optional height] x width`. - """ - - def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None): - super(Upsample, self).__init__() - if isinstance(scale_factor, tuple): - self.scale_factor = tuple(float(factor) for factor in scale_factor) - else: - self.scale_factor = float(scale_factor) if scale_factor else None - self.mode = mode - self.size = size - self.align_corners = align_corners - - def forward(self, x): - return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) - - def extra_repr(self): - if self.scale_factor is not None: - info = f'scale_factor={self.scale_factor}' - else: - info = f'size={self.size}' - info += f', mode={self.mode}' - return info - - -def pixel_unshuffle(x, scale): - """ Pixel unshuffle. - Args: - x (Tensor): Input feature with shape (b, c, hh, hw). - scale (int): Downsample ratio. - Returns: - Tensor: the pixel unshuffled feature. 
- """ - b, c, hh, hw = x.size() - out_channel = c * (scale**2) - assert hh % scale == 0 and hw % scale == 0 - h = hh // scale - w = hw // scale - x_view = x.view(b, c, h, scale, w, scale) - return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) - - -def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, - pad_type='zero', norm_type=None, act_type='relu', convtype='Conv2D'): - """ - Pixel shuffle layer - (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional - Neural Network, CVPR17) - """ - conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias, - pad_type=pad_type, norm_type=None, act_type=None, convtype=convtype) - pixel_shuffle = nn.PixelShuffle(upscale_factor) - - n = norm(norm_type, out_nc) if norm_type else None - a = act(act_type) if act_type else None - return sequential(conv, pixel_shuffle, n, a) - - -def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, - pad_type='zero', norm_type=None, act_type='relu', mode='nearest', convtype='Conv2D'): - """ Upconv layer """ - upscale_factor = (1, upscale_factor, upscale_factor) if convtype == 'Conv3D' else upscale_factor - upsample = Upsample(scale_factor=upscale_factor, mode=mode) - conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, - pad_type=pad_type, norm_type=norm_type, act_type=act_type, convtype=convtype) - return sequential(upsample, conv) - - - - - - - - -#################### -# Basic blocks -#################### - - -def make_layer(basic_block, num_basic_block, **kwarg): - """Make layers by stacking the same blocks. - Args: - basic_block (nn.module): nn.module class for basic block. (block) - num_basic_block (int): number of blocks. (n_layers) - Returns: - nn.Sequential: Stacked blocks in nn.Sequential. 
- """ - layers = [] - for _ in range(num_basic_block): - layers.append(basic_block(**kwarg)) - return nn.Sequential(*layers) - - -def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0): - """ activation helper """ - act_type = act_type.lower() - if act_type == 'relu': - layer = nn.ReLU(inplace) - elif act_type in ('leakyrelu', 'lrelu'): - layer = nn.LeakyReLU(neg_slope, inplace) - elif act_type == 'prelu': - layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) - elif act_type == 'tanh': # [-1, 1] range output - layer = nn.Tanh() - elif act_type == 'sigmoid': # [0, 1] range output - layer = nn.Sigmoid() - else: - raise NotImplementedError(f'activation layer [{act_type}] is not found') - return layer - - -class Identity(nn.Module): - def __init__(self, *kwargs): - super(Identity, self).__init__() - - def forward(self, x, *kwargs): - return x - - -def norm(norm_type, nc): - """ Return a normalization layer """ - norm_type = norm_type.lower() - if norm_type == 'batch': - layer = nn.BatchNorm2d(nc, affine=True) - elif norm_type == 'instance': - layer = nn.InstanceNorm2d(nc, affine=False) - elif norm_type == 'none': - def norm_layer(x): return Identity() - else: - raise NotImplementedError(f'normalization layer [{norm_type}] is not found') - return layer - - -def pad(pad_type, padding): - """ padding layer helper """ - pad_type = pad_type.lower() - if padding == 0: - return None - if pad_type == 'reflect': - layer = nn.ReflectionPad2d(padding) - elif pad_type == 'replicate': - layer = nn.ReplicationPad2d(padding) - elif pad_type == 'zero': - layer = nn.ZeroPad2d(padding) - else: - raise NotImplementedError(f'padding layer [{pad_type}] is not implemented') - return layer - - -def get_valid_padding(kernel_size, dilation): - kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) - padding = (kernel_size - 1) // 2 - return padding - - -class ShortcutBlock(nn.Module): - """ Elementwise sum the output of a submodule to its input """ - def __init__(self, submodule): - super(ShortcutBlock, self).__init__() - self.sub = submodule - - def forward(self, x): - output = x + self.sub(x) - return output - - def __repr__(self): - return 'Identity + \n|' + self.sub.__repr__().replace('\n', '\n|') - - -def sequential(*args): - """ Flatten Sequential. It unwraps nn.Sequential. """ - if len(args) == 1: - if isinstance(args[0], OrderedDict): - raise NotImplementedError('sequential does not support OrderedDict input.') - return args[0] # No sequential is needed. 
- modules = [] - for module in args: - if isinstance(module, nn.Sequential): - for submodule in module.children(): - modules.append(submodule) - elif isinstance(module, nn.Module): - modules.append(module) - return nn.Sequential(*modules) - - -def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, - pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D', - spectral_norm=False): - """ Conv layer with padding, normalization, activation """ - assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]' - padding = get_valid_padding(kernel_size, dilation) - p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None - padding = padding if pad_type == 'zero' else 0 - - if convtype=='PartialConv2D': - from torchvision.ops import PartialConv2d # this is definitely not going to work, but PartialConv2d doesn't work anyway and this shuts up static analyzer - c = PartialConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - elif convtype=='DeformConv2D': - from torchvision.ops import DeformConv2d # not tested - c = DeformConv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - elif convtype=='Conv3D': - c = nn.Conv3d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - else: - c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, - dilation=dilation, bias=bias, groups=groups) - - if spectral_norm: - c = nn.utils.spectral_norm(c) - - a = act(act_type) if act_type else None - if 'CNA' in mode: - n = norm(norm_type, out_nc) if norm_type else None - return sequential(p, c, n, a) - elif mode == 'NAC': - if norm_type is None and act_type is not None: - a = act(act_type, inplace=False) - n = norm(norm_type, in_nc) if norm_type else None - return sequential(n, a, p, c) diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 01d668ec..6b6f17c4 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -1,8 +1,5 @@ import os -import facexlib -import gfpgan - import modules.face_restoration from modules import paths, shared, devices, modelloader, errors @@ -41,6 +38,8 @@ def gfpgann(): print("Unable to load gfpgan model!") return None + import facexlib.detection.retinaface + if hasattr(facexlib.detection.retinaface, 'device'): facexlib.detection.retinaface.device = devices.device_gfpgan model_file_path = model_file @@ -81,8 +80,10 @@ gfpgan_constructor = None def setup_model(dirname): try: os.makedirs(model_path, exist_ok=True) - from gfpgan import GFPGANer - from facexlib import detection, parsing # noqa: F401 + import gfpgan + import facexlib.detection + import facexlib.parsing + global user_path global have_gfpgan global gfpgan_constructor @@ -111,7 +112,7 @@ def setup_model(dirname): facexlib.parsing.load_file_from_url = facex_load_file_from_url2 user_path = dirname have_gfpgan = True - gfpgan_constructor = GFPGANer + gfpgan_constructor = gfpgan.GFPGANer class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): def name(self): diff --git a/modules/launch_utils.py b/modules/launch_utils.py index dabef0f5..c2cbd8ce 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -345,13 +345,11 @@ def prepare_environment(): stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") 
stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') - codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "ab527a9a6d347f364e3d185ba6d714e22d80cb3c") - codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") try: @@ -408,15 +406,10 @@ def prepare_environment(): git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) - git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) startup_timer.record("clone repositores") - if not is_installed("lpips"): - run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer") - startup_timer.record("install CodeFormer requirements") - if not os.path.isfile(requirements_file): requirements_file = os.path.join(script_path, requirements_file) diff --git a/modules/modelloader.py b/modules/modelloader.py index 098bcb79..30116932 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import os import shutil import importlib @@ -10,6 +11,9 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale from modules.paths import script_path, models_path +logger = logging.getLogger(__name__) + + def load_file_from_url( url: str, *, @@ -177,3 +181,15 @@ def load_upscalers(): # Special case for UpscalerNone keeps it at the beginning of the list. 
key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else "" ) + + +def load_spandrel_model(path, *, device, half: bool = False, dtype=None): + import spandrel + model = spandrel.ModelLoader(device=device).load_from_file(path) + if half: + model = model.model.half() + if dtype: + model = model.model.to(dtype=dtype) + model.eval() + logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype) + return model diff --git a/modules/paths.py b/modules/paths.py index 187b9496..03064651 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -38,7 +38,6 @@ mute_sdxl_imports() path_dirs = [ (sd_path, 'ldm', 'Stable Diffusion', []), (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), - (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), ] diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 02841c30..332d8f4b 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,9 +1,6 @@ import os -import numpy as np -from PIL import Image -from realesrgan import RealESRGANer - +from modules.upscaler_utils import upscale_with_model from modules.upscaler import Upscaler, UpscalerData from modules.shared import cmd_opts, opts from modules import modelloader, errors @@ -14,29 +11,20 @@ class UpscalerRealESRGAN(Upscaler): self.name = "RealESRGAN" self.user_path = path super().__init__() - try: - from basicsr.archs.rrdbnet_arch import RRDBNet # noqa: F401 - from realesrgan import RealESRGANer # noqa: F401 - from realesrgan.archs.srvgg_arch import SRVGGNetCompact # noqa: F401 - self.enable = True - self.scalers = [] - scalers = self.load_models(path) + self.enable = True + self.scalers = [] + scalers = get_realesrgan_models(self) - local_model_paths = self.find_models(ext_filter=[".pth"]) - for scaler in scalers: - if scaler.local_data_path.startswith("http"): - filename = modelloader.friendly_name(scaler.local_data_path) - local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")] - if local_model_candidates: - scaler.local_data_path = local_model_candidates[0] + local_model_paths = self.find_models(ext_filter=[".pth"]) + for scaler in scalers: + if scaler.local_data_path.startswith("http"): + filename = modelloader.friendly_name(scaler.local_data_path) + local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")] + if local_model_candidates: + scaler.local_data_path = local_model_candidates[0] - if scaler.name in opts.realesrgan_enabled_models: - self.scalers.append(scaler) - - except Exception: - errors.report("Error importing Real-ESRGAN", exc_info=True) - self.enable = False - self.scalers = [] + if scaler.name in opts.realesrgan_enabled_models: + self.scalers.append(scaler) def do_upscale(self, img, path): if not self.enable: @@ -48,20 +36,18 @@ class UpscalerRealESRGAN(Upscaler): errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True) return img - upsampler = RealESRGANer( - scale=info.scale, - model_path=info.local_data_path, - model=info.model(), - half=not cmd_opts.no_half and not cmd_opts.upcast_sampling, - tile=opts.ESRGAN_tile, - tile_pad=opts.ESRGAN_tile_overlap, + mod = modelloader.load_spandrel_model( 
+ info.local_data_path, device=self.device, + half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), + ) + return upscale_with_model( + mod, + img, + tile_size=opts.ESRGAN_tile, + tile_overlap=opts.ESRGAN_tile_overlap, + # TODO: `outscale`? ) - - upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0] - - image = Image.fromarray(upsampled) - return image def load_model(self, path): for scaler in self.scalers: @@ -76,58 +62,43 @@ class UpscalerRealESRGAN(Upscaler): return scaler raise ValueError(f"Unable to find model info: {path}") - def load_models(self, _): - return get_realesrgan_models(self) - -def get_realesrgan_models(scaler): - try: - from basicsr.archs.rrdbnet_arch import RRDBNet - from realesrgan.archs.srvgg_arch import SRVGGNetCompact - models = [ - UpscalerData( - name="R-ESRGAN General 4xV3", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", - scale=4, - upscaler=scaler, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') - ), - UpscalerData( - name="R-ESRGAN General WDN 4xV3", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", - scale=4, - upscaler=scaler, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') - ), - UpscalerData( - name="R-ESRGAN AnimeVideo", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", - scale=4, - upscaler=scaler, - model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu') - ), - UpscalerData( - name="R-ESRGAN 4x+", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", - scale=4, - upscaler=scaler, - model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4) - ), - UpscalerData( - name="R-ESRGAN 4x+ Anime6B", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", - scale=4, - upscaler=scaler, - model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) - ), - UpscalerData( - name="R-ESRGAN 2x+", - path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", - scale=2, - upscaler=scaler, - model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2) - ), - ] - return models - except Exception: - errors.report("Error making Real-ESRGAN models list", exc_info=True) +def get_realesrgan_models(scaler: UpscalerRealESRGAN): + return [ + UpscalerData( + name="R-ESRGAN General 4xV3", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN General WDN 4xV3", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN AnimeVideo", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN 4x+", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN 4x+ Anime6B", + 
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", + scale=4, + upscaler=scaler, + ), + UpscalerData( + name="R-ESRGAN 2x+", + path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", + scale=2, + upscaler=scaler, + ), + ] diff --git a/modules/sysinfo.py b/modules/sysinfo.py index b669edd0..5abf616b 100644 --- a/modules/sysinfo.py +++ b/modules/sysinfo.py @@ -26,11 +26,9 @@ environment_whitelist = { "OPENCLIP_PACKAGE", "STABLE_DIFFUSION_REPO", "K_DIFFUSION_REPO", - "CODEFORMER_REPO", "BLIP_REPO", "STABLE_DIFFUSION_COMMIT_HASH", "K_DIFFUSION_COMMIT_HASH", - "CODEFORMER_COMMIT_HASH", "BLIP_COMMIT_HASH", "COMMANDLINE_ARGS", "IGNORE_CMD_ARGS_ERRORS", diff --git a/modules/upscaler.py b/modules/upscaler.py index b256e085..3aee69db 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -98,6 +98,9 @@ class UpscalerData: self.scale = scale self.model = model + def __repr__(self): + return f"" + class UpscalerNone(Upscaler): name = "None" diff --git a/requirements.txt b/requirements.txt index 80b43845..36f5674a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ basicsr blendmodes clean-fid einops +facexlib fastapi>=0.90.1 gfpgan gradio==3.41.2 @@ -20,13 +21,11 @@ open-clip-torch piexif psutil pytorch_lightning -realesrgan requests resize-right safetensors scikit-image>=0.19 -timm tomesd torch torchdiffeq diff --git a/requirements_versions.txt b/requirements_versions.txt index cb7403a9..042fa708 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -5,6 +5,7 @@ basicsr==1.4.2 blendmodes==2022 clean-fid==0.1.35 einops==0.4.1 +facexlib==0.3.0 fastapi==0.94.0 gfpgan==1.3.8 gradio==3.41.2 @@ -19,11 +20,10 @@ open-clip-torch==2.20.0 piexif==1.1.3 psutil==5.9.5 pytorch_lightning==1.9.4 -realesrgan==0.3.0 resize-right==0.0.2 safetensors==0.3.1 scikit-image==0.21.0 -timm==0.9.2 +spandrel==0.1.6 tomesd==0.1.3 torch torchdiffeq==0.2.3 -- cgit v1.2.3 From 4ad0c0c0a805da4bac03cff86ea17c25a1291546 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 16:37:03 +0200 Subject: Verify architecture for loaded Spandrel models --- extensions-builtin/ScuNET/scripts/scunet_model.py | 2 +- extensions-builtin/SwinIR/scripts/swinir_model.py | 1 + modules/codeformer_model.py | 1 + modules/esrgan_model.py | 1 + modules/gfpgan_model.py | 1 + modules/hat_model.py | 1 + modules/modelloader.py | 13 ++++++++++++- modules/realesrgan_model.py | 7 ++++--- 8 files changed, 22 insertions(+), 5 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 18cf8e1a..5f3dd08b 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -121,7 +121,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path - return modelloader.load_spandrel_model(filename, device=device) + return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet') def on_ui_settings(): diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 85c18b9e..aae159af 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -75,6 +75,7 @@ class 
UpscalerSwinIR(Upscaler): filename, device=self._get_device(), dtype=devices.dtype, + expected_architecture="SwinIR", ) if getattr(opts, 'SWIN_torch_compile', False): try: diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index ceda4bab..44b84618 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -37,6 +37,7 @@ class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration): return modelloader.load_spandrel_model( model_path, device=devices.device_codeformer, + expected_architecture='CodeFormer', ).model raise ValueError("No codeformer model found") diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index a7c7c9e3..70041ab0 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -49,6 +49,7 @@ class UpscalerESRGAN(Upscaler): return modelloader.load_spandrel_model( filename, device=('cpu' if devices.device_esrgan.type == 'mps' else None), + expected_architecture='ESRGAN', ) diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index a356b56f..48f8ad5e 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -37,6 +37,7 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration): net = modelloader.load_spandrel_model( model_path, device=self.get_device(), + expected_architecture='GFPGAN', ).model net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 return net diff --git a/modules/hat_model.py b/modules/hat_model.py index 553e1941..7f2abb41 100644 --- a/modules/hat_model.py +++ b/modules/hat_model.py @@ -39,4 +39,5 @@ class UpscalerHAT(Upscaler): return modelloader.load_spandrel_model( path, device=devices.device_esrgan, # TODO: should probably be device_hat + expected_architecture='HAT', ) diff --git a/modules/modelloader.py b/modules/modelloader.py index 30116932..f4182559 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -6,6 +6,8 @@ import shutil import importlib from urllib.parse import urlparse +import torch + from modules import shared from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone from modules.paths import script_path, models_path @@ -183,9 +185,18 @@ def load_upscalers(): ) -def load_spandrel_model(path, *, device, half: bool = False, dtype=None): +def load_spandrel_model( + path: str, + *, + device: str | torch.device | None, + half: bool = False, + dtype: str | None = None, + expected_architecture: str | None = None, +): import spandrel model = spandrel.ModelLoader(device=device).load_from_file(path) + if expected_architecture and model.architecture != expected_architecture: + raise TypeError(f"Model {path} is not a {expected_architecture} model") if half: model = model.model.half() if dtype: diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 332d8f4b..2a2be5ad 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,9 +1,9 @@ import os -from modules.upscaler_utils import upscale_with_model -from modules.upscaler import Upscaler, UpscalerData -from modules.shared import cmd_opts, opts from modules import modelloader, errors +from modules.shared import cmd_opts, opts +from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import upscale_with_model class UpscalerRealESRGAN(Upscaler): @@ -40,6 +40,7 @@ class UpscalerRealESRGAN(Upscaler): info.local_data_path, device=self.device, half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), + expected_architecture="RealESRGAN", ) return upscale_with_model( mod, 
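
The hunks above thread an expected_architecture argument from each upscaler and face restorer into load_spandrel_model, which now raises TypeError when spandrel reports a different architecture for the loaded file. Below is a minimal caller-side sketch of that behaviour; the checkpoint path is a hypothetical placeholder for illustration, not a file shipped with the webui.

# Sketch only: exercises the expected_architecture guard introduced in this commit.
# "models/ESRGAN/4x_example.pth" is a hypothetical path, not part of the repository.
from modules import modelloader

try:
    descriptor = modelloader.load_spandrel_model(
        "models/ESRGAN/4x_example.pth",
        device="cpu",
        expected_architecture="ESRGAN",
    )
except TypeError:
    # Raised by the new check when spandrel identifies the checkpoint as a
    # different architecture (e.g. a SwinIR or SCUNet model was supplied).
    descriptor = None
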
-- cgit v1.2.3 From bc5ae74c7d8949bab37e260b16e76889b9968099 Mon Sep 17 00:00:00 2001 From: Learwin <6223515+Learwin@users.noreply.github.com> Date: Sat, 30 Dec 2023 21:52:27 +0100 Subject: Added negative prompts to extra networks lora --- extensions-builtin/Lora/ui_edit_user_metadata.py | 14 ++++++++-- extensions-builtin/Lora/ui_extra_networks_lora.py | 9 +++++++ javascript/extraNetworks.js | 31 +++++++++++++++-------- modules/ui_extra_networks.py | 5 +++- 4 files changed, 46 insertions(+), 13 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index c7011909..f7859b21 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -54,12 +54,14 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.slider_preferred_weight = None self.edit_notes = None - def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes): + def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, negative_weight, notes): user_metadata = self.get_user_metadata(name) user_metadata["description"] = desc user_metadata["sd version"] = sd_version user_metadata["activation text"] = activation_text user_metadata["preferred weight"] = preferred_weight + user_metadata["negative text"] = negative_text + user_metadata["negative weight"] = negative_weight user_metadata["notes"] = notes self.write_user_metadata(name, user_metadata) @@ -127,6 +129,8 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False), user_metadata.get('activation text', ''), float(user_metadata.get('preferred weight', 0.0)), + user_metadata.get('negative text', ''), + float(user_metadata.get('negative weight', 0.0)), gr.update(visible=True if tags else False), gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False), ] @@ -162,7 +166,8 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.taginfo = gr.HighlightedText(label="Training dataset tags") self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora") self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) - + self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts") + self.slider_negative_weight = gr.Slider(label='Preferred negative weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) with gr.Row() as row_random_prompt: with gr.Column(scale=8): random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) @@ -198,6 +203,8 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.taginfo, self.edit_activation_text, self.slider_preferred_weight, + self.edit_negative_text, + self.slider_negative_weight, row_random_prompt, random_prompt, ] @@ -211,7 +218,10 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.select_sd_version, self.edit_activation_text, self.slider_preferred_weight, + self.edit_negative_text, + self.slider_negative_weight, self.edit_notes, ] + self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components) diff --git 
a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index df02c663..09ce2a05 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -45,6 +45,15 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): if activation_text: item["prompt"] += " + " + quote_js(" " + activation_text) + negative_prompt = item["user_metadata"].get("negative text") + preferred_negative_weight = item["user_metadata"].get("negative weight") + item["negative_prompt"] = quote_js("") + if negative_prompt: + neg_prompt = negative_prompt + if (preferred_negative_weight > 0): + neg_prompt = '(' + negative_prompt + ':' + str(preferred_negative_weight) + ')' + item["negative_prompt"] = quote_js(neg_prompt) + sd_version = item["user_metadata"].get("sd version") if sd_version in network.SdVersion.__members__: item["sd_version"] = sd_version diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 98a7abb7..2bb9795d 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -185,8 +185,10 @@ onUiLoaded(setupExtraNetworks); var re_extranet = /<([^:^>]+:[^:]+):[\d.]+>(.*)/; var re_extranet_g = /<([^:^>]+:[^:]+):[\d.]+>/g; -function tryToRemoveExtraNetworkFromPrompt(textarea, text) { - var m = text.match(re_extranet); +var re_extranet_neg = /\(([^:^>]+:[\d.]+)\)/; +var re_extranet_g_neg = /\(([^:^>]+:[\d.]+)\)/g; +function tryToRemoveExtraNetworkFromPrompt(textarea, text, isNeg) { + var m = text.match(isNeg ? re_extranet_neg : re_extranet); var replaced = false; var newTextareaText; if (m) { @@ -194,8 +196,8 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { var extraTextAfterNet = m[2]; var partToSearch = m[1]; var foundAtPosition = -1; - newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, net, pos) { - m = found.match(re_extranet); + newTextareaText = textarea.value.replaceAll(isNeg ? re_extranet_g_neg : re_extranet_g, function(found, net, pos) { + m = found.match(isNeg ? re_extranet_neg : re_extranet); if (m[1] == partToSearch) { replaced = true; foundAtPosition = pos; @@ -205,7 +207,7 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { }); if (foundAtPosition >= 0) { - if (newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { + if (extraTextAfterNet && newTextareaText.substr(foundAtPosition, extraTextAfterNet.length) == extraTextAfterNet) { newTextareaText = newTextareaText.substr(0, foundAtPosition) + newTextareaText.substr(foundAtPosition + extraTextAfterNet.length); } if (newTextareaText.substr(foundAtPosition - extraTextBeforeNet.length, extraTextBeforeNet.length) == extraTextBeforeNet) { @@ -230,14 +232,23 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text) { return false; } -function cardClicked(tabname, textToAdd, allowNegativePrompt) { - var textarea = allowNegativePrompt ? 
activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"); +function updatePromptArea(text, textArea, isNeg) { - if (!tryToRemoveExtraNetworkFromPrompt(textarea, textToAdd)) { - textarea.value = textarea.value + opts.extra_networks_add_text_separator + textToAdd; + if (!tryToRemoveExtraNetworkFromPrompt(textArea, text, isNeg)) { + textArea.value = textArea.value + opts.extra_networks_add_text_separator + text; } - updateInput(textarea); + updateInput(textArea); +} + +function cardClicked(tabname, textToAdd, textToAddNegative, allowNegativePrompt) { + if (textToAddNegative.length > 0) { + updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")) + updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true) + } else { + var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"); + updatePromptArea(textToAdd, textarea) + } } function saveCardPreview(event, tabname, filename) { diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index fe5d3ba3..b8c02241 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -223,7 +223,10 @@ class ExtraNetworksPage: onclick = item.get("onclick", None) if onclick is None: - onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + if "negative_prompt" in item: + onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {item["negative_prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"' + else: + onclick = '"' + html.escape(f"""return cardClicked({quote_js(tabname)}, {item["prompt"]}, {'""'}, {"true" if self.allow_negative_prompt else "false"})""") + '"' height = f"height: {shared.opts.extra_networks_card_height}px;" if shared.opts.extra_networks_card_height else '' width = f"width: {shared.opts.extra_networks_card_width}px;" if shared.opts.extra_networks_card_width else '' -- cgit v1.2.3 From a2f23f9d22dde87bf2529dcb2854a6a5d3d44278 Mon Sep 17 00:00:00 2001 From: Learwin <6223515+Learwin@users.noreply.github.com> Date: Sat, 30 Dec 2023 22:16:51 +0100 Subject: Code Style fixes --- extensions-builtin/Lora/ui_extra_networks_lora.py | 4 ++-- javascript/extraNetworks.js | 6 +++--- modules/upscaler_utils.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 09ce2a05..9a6624e3 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -52,8 +52,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): neg_prompt = negative_prompt if (preferred_negative_weight > 0): neg_prompt = '(' + negative_prompt + ':' + str(preferred_negative_weight) + ')' - item["negative_prompt"] = quote_js(neg_prompt) - + item["negative_prompt"] = quote_js(neg_prompt) + sd_version = item["user_metadata"].get("sd version") if sd_version in network.SdVersion.__members__: item["sd_version"] = sd_version diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 2bb9795d..f1ad19a6 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -243,11 +243,11 @@ function updatePromptArea(text, textArea, isNeg) { 
function cardClicked(tabname, textToAdd, textToAddNegative, allowNegativePrompt) { if (textToAddNegative.length > 0) { - updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")) - updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true) + updatePromptArea(textToAdd, gradioApp().querySelector("#" + tabname + "_prompt > label > textarea")); + updatePromptArea(textToAddNegative, gradioApp().querySelector("#" + tabname + "_neg_prompt > label > textarea"), true); } else { var textarea = allowNegativePrompt ? activePromptTextarea[tabname] : gradioApp().querySelector("#" + tabname + "_prompt > label > textarea"); - updatePromptArea(textToAdd, textarea) + updatePromptArea(textToAdd, textarea); } } diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 39f78a0b..1d610dbf 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import torch import tqdm from PIL import Image -from modules import devices, images +from modules import images logger = logging.getLogger(__name__) -- cgit v1.2.3 From 777af661a21821994993df3ef566b01df2bb61a0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 00:09:51 +0200 Subject: Be more clear about Spandrel model nomenclature --- extensions-builtin/SwinIR/scripts/swinir_model.py | 6 +++--- modules/gfpgan_model.py | 10 ++++++---- modules/modelloader.py | 2 +- modules/realesrgan_model.py | 4 ++-- modules/upscaler_utils.py | 2 +- 5 files changed, 13 insertions(+), 11 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index aae159af..95c7ec64 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -71,7 +71,7 @@ class UpscalerSwinIR(Upscaler): else: filename = path - model = modelloader.load_spandrel_model( + model_descriptor = modelloader.load_spandrel_model( filename, device=self._get_device(), dtype=devices.dtype, @@ -79,10 +79,10 @@ class UpscalerSwinIR(Upscaler): ) if getattr(opts, 'SWIN_torch_compile', False): try: - model = torch.compile(model) + model_descriptor.model.compile() except Exception: logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True) - return model + return model_descriptor def _get_device(self): return devices.get_device_for('swinir') diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 48f8ad5e..445b0409 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -3,6 +3,8 @@ from __future__ import annotations import logging import os +import torch + from modules import ( devices, errors, @@ -25,7 +27,7 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration): def get_device(self): return devices.device_gfpgan - def load_net(self) -> None: + def load_net(self) -> torch.Module: for model_path in modelloader.load_models( model_path=self.model_path, model_url=model_url, @@ -34,13 +36,13 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration): ext_filter=['.pth'], ): if 'GFPGAN' in os.path.basename(model_path): - net = modelloader.load_spandrel_model( + model = modelloader.load_spandrel_model( model_path, device=self.get_device(), expected_architecture='GFPGAN', ).model - net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81 - return net + model.different_w = True # see 
https://github.com/chaiNNer-org/spandrel/pull/81 + return model raise ValueError("No GFPGAN model found") def restore(self, np_image): diff --git a/modules/modelloader.py b/modules/modelloader.py index 8bcee08c..a7194137 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -143,7 +143,7 @@ def load_spandrel_model( *, device: str | torch.device | None, half: bool = False, - dtype: str | None = None, + dtype: str | torch.dtype | None = None, expected_architecture: str | None = None, ) -> spandrel.ModelDescriptor: import spandrel diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 65f2e880..4d35b695 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -36,14 +36,14 @@ class UpscalerRealESRGAN(Upscaler): errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True) return img - mod = modelloader.load_spandrel_model( + model_descriptor = modelloader.load_spandrel_model( info.local_data_path, device=self.device, half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), expected_architecture="ESRGAN", # "RealESRGAN" isn't a specific thing for Spandrel ) return upscale_with_model( - mod, + model_descriptor, img, tile_size=opts.ESRGAN_tile, tile_overlap=opts.ESRGAN_tile_overlap, diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index dde5d7ad..174c9bc3 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import torch import tqdm from PIL import Image -from modules import devices, images +from modules import images logger = logging.getLogger(__name__) -- cgit v1.2.3 From 6f86b62a1be7993073ba3a789d522e0b8870605a Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 22:53:49 +0200 Subject: Deduplicate tiled inference code from SwinIR/ScuNET --- extensions-builtin/ScuNET/scripts/scunet_model.py | 55 ++++------------- extensions-builtin/SwinIR/scripts/swinir_model.py | 57 ++---------------- modules/upscaler_utils.py | 72 ++++++++++++++++++++++- 3 files changed, 87 insertions(+), 97 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 5f3dd08b..f799cb76 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -3,12 +3,11 @@ import sys import PIL.Image import numpy as np import torch -from tqdm import tqdm import modules.upscaler from modules import devices, modelloader, script_callbacks, errors - from modules.shared import opts +from modules.upscaler_utils import tiled_upscale_2 class UpscalerScuNET(modules.upscaler.Upscaler): @@ -40,47 +39,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler): scalers.append(scaler_data2) self.scalers = scalers - @staticmethod - @torch.no_grad() - def tiled_inference(img, model): - # test the image tile by tile - h, w = img.shape[2:] - tile = opts.SCUNET_tile - tile_overlap = opts.SCUNET_tile_overlap - if tile == 0: - return model(img) - - device = devices.get_device_for('scunet') - assert tile % 8 == 0, "tile size should be a multiple of window_size" - sf = 1 - - stride = tile - tile_overlap - h_idx_list = list(range(0, h - tile, stride)) + [h - tile] - w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device) - W = torch.zeros_like(E, dtype=devices.dtype, device=device) - - with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar: - for h_idx in h_idx_list: 
- - for w_idx in w_idx_list: - - in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] - - out_patch = model(in_patch) - out_patch_mask = torch.ones_like(out_patch) - - E[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch) - W[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch_mask) - pbar.update(1) - output = E.div_(W) - - return output - def do_upscale(self, img: PIL.Image.Image, selected_file): devices.torch_gc() @@ -104,7 +62,16 @@ class UpscalerScuNET(modules.upscaler.Upscaler): _img[:, :, :h, :w] = torch_img # pad image torch_img = _img - torch_output = self.tiled_inference(torch_img, model).squeeze(0) + with torch.no_grad(): + torch_output = tiled_upscale_2( + torch_img, + model, + tile_size=opts.SCUNET_tile, + tile_overlap=opts.SCUNET_tile_overlap, + scale=1, + device=devices.get_device_for('scunet'), + desc="ScuNET tiles", + ).squeeze(0) torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() del torch_img, torch_output diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 95c7ec64..8a555c79 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -4,11 +4,11 @@ import sys import numpy as np import torch from PIL import Image -from tqdm import tqdm from modules import modelloader, devices, script_callbacks, shared -from modules.shared import opts, state +from modules.shared import opts from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import tiled_upscale_2 SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" @@ -110,14 +110,14 @@ def upscale( w_pad = (w_old // window_size + 1) * window_size - w_old img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] - output = inference( + output = tiled_upscale_2( img, model, - tile=tile, + tile_size=tile, tile_overlap=tile_overlap, - window_size=window_size, scale=scale, device=device, + desc="SwinIR tiles", ) output = output[..., : h_old * scale, : w_old * scale] output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() @@ -129,53 +129,6 @@ def upscale( return Image.fromarray(output, "RGB") -def inference( - img, - model, - *, - tile: int, - tile_overlap: int, - window_size: int, - scale: int, - device, -): - # test the image tile by tile - b, c, h, w = img.size() - tile = min(tile, h, w) - assert tile % window_size == 0, "tile size should be a multiple of window_size" - sf = scale - - stride = tile - tile_overlap - h_idx_list = list(range(0, h - tile, stride)) + [h - tile] - w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img) - W = torch.zeros_like(E, dtype=devices.dtype, device=device) - - with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: - for h_idx in h_idx_list: - if state.interrupted or state.skipped: - break - - for w_idx in w_idx_list: - if state.interrupted or state.skipped: - break - - in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] - out_patch = model(in_patch) - out_patch_mask = torch.ones_like(out_patch) - - E[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + 
tile) * sf - ].add_(out_patch) - W[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch_mask) - pbar.update(1) - output = E.div_(W) - - return output - - def on_ui_settings(): import gradio as gr diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 174c9bc3..8e413854 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import torch import tqdm from PIL import Image -from modules import images +from modules import images, shared logger = logging.getLogger(__name__) @@ -68,3 +68,73 @@ def upscale_with_model( overlap=grid.overlap * scale_factor, ) return images.combine_grid(newgrid) + + +def tiled_upscale_2( + img, + model, + *, + tile_size: int, + tile_overlap: int, + scale: int, + device, + desc="Tiled upscale", +): + # Alternative implementation of `upscale_with_model` originally used by + # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and + # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in + # Pillow space without weighting. + b, c, h, w = img.size() + tile_size = min(tile_size, h, w) + + if tile_size <= 0: + logger.debug("Upscaling %s without tiling", img.shape) + return model(img) + + stride = tile_size - tile_overlap + h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size] + w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size] + result = torch.zeros( + b, + c, + h * scale, + w * scale, + device=device, + ).type_as(img) + weights = torch.zeros_like(result) + logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) + with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar: + for h_idx in h_idx_list: + if shared.state.interrupted or shared.state.skipped: + break + + for w_idx in w_idx_list: + if shared.state.interrupted or shared.state.skipped: + break + + in_patch = img[ + ..., + h_idx : h_idx + tile_size, + w_idx : w_idx + tile_size, + ] + out_patch = model(in_patch) + + result[ + ..., + h_idx * scale : (h_idx + tile_size) * scale, + w_idx * scale : (w_idx + tile_size) * scale, + ].add_(out_patch) + + out_patch_mask = torch.ones_like(out_patch) + + weights[ + ..., + h_idx * scale : (h_idx + tile_size) * scale, + w_idx * scale : (w_idx + tile_size) * scale, + ].add_(out_patch_mask) + + pbar.update(1) + + output = result.div_(weights) + + return output -- cgit v1.2.3 From d4945f4422e5a0bf31a6dbe4c1aeedd78c09eacb Mon Sep 17 00:00:00 2001 From: Learwin <6223515+Learwin@users.noreply.github.com> Date: Sun, 31 Dec 2023 13:22:30 +0100 Subject: Removed weight slider for negative prompts --- extensions-builtin/Lora/ui_edit_user_metadata.py | 7 +------ extensions-builtin/Lora/ui_extra_networks_lora.py | 6 +----- 2 files changed, 2 insertions(+), 11 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index f7859b21..3160aecf 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -54,14 +54,13 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.slider_preferred_weight = None self.edit_notes = None - def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, negative_weight, notes): + def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes): user_metadata = 
self.get_user_metadata(name) user_metadata["description"] = desc user_metadata["sd version"] = sd_version user_metadata["activation text"] = activation_text user_metadata["preferred weight"] = preferred_weight user_metadata["negative text"] = negative_text - user_metadata["negative weight"] = negative_weight user_metadata["notes"] = notes self.write_user_metadata(name, user_metadata) @@ -130,7 +129,6 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) user_metadata.get('activation text', ''), float(user_metadata.get('preferred weight', 0.0)), user_metadata.get('negative text', ''), - float(user_metadata.get('negative weight', 0.0)), gr.update(visible=True if tags else False), gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False), ] @@ -167,7 +165,6 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora") self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts") - self.slider_negative_weight = gr.Slider(label='Preferred negative weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) with gr.Row() as row_random_prompt: with gr.Column(scale=8): random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) @@ -204,7 +201,6 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.edit_activation_text, self.slider_preferred_weight, self.edit_negative_text, - self.slider_negative_weight, row_random_prompt, random_prompt, ] @@ -219,7 +215,6 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.edit_activation_text, self.slider_preferred_weight, self.edit_negative_text, - self.slider_negative_weight, self.edit_notes, ] diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 9a6624e3..e714fac4 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -46,13 +46,9 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): item["prompt"] += " + " + quote_js(" " + activation_text) negative_prompt = item["user_metadata"].get("negative text") - preferred_negative_weight = item["user_metadata"].get("negative weight") item["negative_prompt"] = quote_js("") if negative_prompt: - neg_prompt = negative_prompt - if (preferred_negative_weight > 0): - neg_prompt = '(' + negative_prompt + ':' + str(preferred_negative_weight) + ')' - item["negative_prompt"] = quote_js(neg_prompt) + item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)') sd_version = item["user_metadata"].get("sd version") if sd_version in network.SdVersion.__members__: -- cgit v1.2.3 From d859cec696a953dbfd6f69f7735e68661748d579 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 13:53:12 +0300 Subject: infotext.py: rename usages in the codebase --- .../extra-options-section/scripts/extra_options_section.py | 4 ++-- modules/api/api.py | 10 +++++----- modules/img2img.py | 2 +- modules/postprocessing.py | 4 ++-- modules/processing.py | 4 ++-- modules/processing_scripts/refiner.py | 2 +- modules/processing_scripts/seed.py | 2 +- modules/shared_items.py | 4 ++-- modules/txt2img.py 
| 2 +- modules/ui.py | 4 ++-- modules/ui_common.py | 4 ++-- modules/ui_extra_networks.py | 2 +- modules/ui_extra_networks_user_metadata.py | 4 ++-- modules/ui_postprocessing.py | 2 +- 14 files changed, 25 insertions(+), 25 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index ac2c3de4..8aa901fd 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -1,7 +1,7 @@ import math import gradio as gr -from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste +from modules import scripts, shared, ui_components, ui_settings, infotext from modules.ui_components import FormColumn @@ -25,7 +25,7 @@ class ExtraOptionsSection(scripts.Script): extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img") - mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping} + mapping = {k: v for v, k in infotext.infotext_to_setting_name_mapping} with gr.Blocks() as interface: with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname): diff --git a/modules/api/api.py b/modules/api/api.py index 843c59b0..0e2807de 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -17,7 +17,7 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext, sd_models from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -369,9 +369,9 @@ class Api: if not request.infotext: return {} - possible_fields = generation_parameters_copypaste.paste_fields[tabname]["fields"] + possible_fields = infotext.paste_fields[tabname]["fields"] set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have differenrt names for this - params = generation_parameters_copypaste.parse_generation_parameters(request.infotext) + params = infotext.parse_generation_parameters(request.infotext) def get_field_value(field, params): value = field.function(params) if field.function else params.get(field.label) @@ -408,7 +408,7 @@ class Api: if request.override_settings is None: request.override_settings = {} - overriden_settings = generation_parameters_copypaste.get_override_settings(params) + overriden_settings = infotext.get_override_settings(params) for _, setting_name, value in overriden_settings: if setting_name not in request.override_settings: request.override_settings[setting_name] = value @@ -584,7 +584,7 @@ class Api: if geninfo is None: geninfo = "" - params = generation_parameters_copypaste.parse_generation_parameters(geninfo) + params = infotext.parse_generation_parameters(geninfo) 
script_callbacks.infotext_pasted_callback(geninfo, params) return models.PNGInfoResponse(info=geninfo, items=items, parameters=params) diff --git a/modules/img2img.py b/modules/img2img.py index c583290a..75b3d346 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -7,7 +7,7 @@ from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageErr import gradio as gr from modules import images as imgutil -from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters +from modules.infotext import create_override_settings_dict, parse_generation_parameters from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state from modules.sd_models import get_closet_checkpoint_match diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 0c59fad4..f776f7b6 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -2,7 +2,7 @@ import os from PIL import Image -from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste +from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext from modules.shared import opts @@ -86,7 +86,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, basename = '' forced_filename = None - infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None]) + infotext = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in pp.info.items() if v is not None]) if opts.enable_pnginfo: pp.image.info = existing_pnginfo diff --git a/modules/processing.py b/modules/processing.py index 7789f9a4..b30df60d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -16,7 +16,7 @@ from skimage import exposure from typing import Any import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng from modules.rng import slerp # noqa: F401 from modules.sd_hijack import model_hijack from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes @@ -733,7 +733,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "User": p.user if opts.add_user_name_to_info else None, } - generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) + generation_params_text = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in generation_params.items() if v is not None]) prompt_text = p.main_prompt if use_main_prompt else all_prompts[index] negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else "" diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py index cefad32b..e9941413 100644 --- a/modules/processing_scripts/refiner.py +++ b/modules/processing_scripts/refiner.py @@ -1,7 +1,7 @@ import gradio as gr from modules import scripts, sd_models -from modules.generation_parameters_copypaste import 
PasteField +from modules.infotext import PasteField from modules.ui_common import create_refresh_button from modules.ui_components import InputAccordion diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index a3e16a12..60293278 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -3,7 +3,7 @@ import json import gradio as gr from modules import scripts, ui, errors -from modules.generation_parameters_copypaste import PasteField +from modules.infotext import PasteField from modules.shared import cmd_opts from modules.ui_components import ToolButton diff --git a/modules/shared_items.py b/modules/shared_items.py index 991971ad..e1392472 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -67,14 +67,14 @@ def reload_hypernetworks(): def get_infotext_names(): - from modules import generation_parameters_copypaste, shared + from modules import infotext, shared res = {} for info in shared.opts.data_labels.values(): if info.infotext: res[info.infotext] = 1 - for tab_data in generation_parameters_copypaste.paste_fields.values(): + for tab_data in infotext.paste_fields.values(): for _, name in tab_data.get("fields") or []: if isinstance(name, str): res[name] = 1 diff --git a/modules/txt2img.py b/modules/txt2img.py index e4e18ceb..3a481915 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -2,7 +2,7 @@ from contextlib import closing import modules.scripts from modules import processing -from modules.generation_parameters_copypaste import create_override_settings_dict +from modules.infotext import create_override_settings_dict from modules.shared import opts import modules.shared as shared from modules.ui import plaintext_to_html diff --git a/modules/ui.py b/modules/ui.py index 9db2407e..6451e14c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -21,14 +21,14 @@ from modules.ui_gradio_extensions import reload_javascript from modules.shared import opts, cmd_opts -import modules.generation_parameters_copypaste as parameters_copypaste +import modules.infotext as parameters_copypaste import modules.hypernetworks.ui as hypernetworks_ui import modules.textual_inversion.ui as textual_inversion_ui import modules.textual_inversion.textual_inversion as textual_inversion import modules.shared as shared from modules import prompt_parser from modules.sd_hijack import model_hijack -from modules.generation_parameters_copypaste import image_from_url_text, PasteField +from modules.infotext import image_from_url_text, PasteField create_setting_component = ui_settings.create_setting_component diff --git a/modules/ui_common.py b/modules/ui_common.py index 032ec4af..fd32676f 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -8,10 +8,10 @@ import gradio as gr import subprocess as sp from modules import call_queue, shared -from modules.generation_parameters_copypaste import image_from_url_text +from modules.infotext import image_from_url_text import modules.images from modules.ui_components import ToolButton -import modules.generation_parameters_copypaste as parameters_copypaste +import modules.infotext as parameters_copypaste folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index b8c02241..790af135 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -10,7 +10,7 @@ import json import html from fastapi.exceptions import HTTPException -from modules.generation_parameters_copypaste import 
image_from_url_text +from modules.infotext import image_from_url_text from modules.ui_components import ToolButton extra_pages = [] diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 36a807fc..87aeb6f3 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -5,7 +5,7 @@ import os.path import gradio as gr -from modules import generation_parameters_copypaste, images, sysinfo, errors, ui_extra_networks +from modules import infotext, images, sysinfo, errors, ui_extra_networks class UserMetadataEditor: @@ -181,7 +181,7 @@ class UserMetadataEditor: index = len(gallery) - 1 if index >= len(gallery) else index img_info = gallery[index if index >= 0 else 0] - image = generation_parameters_copypaste.image_from_url_text(img_info) + image = infotext.image_from_url_text(img_info) geninfo, items = images.read_info_from_image(image) images.save_image_with_geninfo(image, geninfo, item["local_preview"]) diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index 13d888e4..b74a1532 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -1,6 +1,6 @@ import gradio as gr from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow -import modules.generation_parameters_copypaste as parameters_copypaste +import modules.infotext as parameters_copypaste def create_ui(): -- cgit v1.2.3 From 5d7d1823afab0a051a3fbbdb3213bae8051350b7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 17:25:30 +0300 Subject: rename infotext.py again, this time to infotext_utils.py; I didn't realize infotext would be used for variable names in multiple places, which makes it awkward to import the module; also fix the bug I caused by this rename that breaks tests --- .../scripts/extra_options_section.py | 4 +- modules/api/api.py | 10 +- modules/img2img.py | 2 +- modules/infotext.py | 502 --------------------- modules/infotext_utils.py | 502 +++++++++++++++++++++ modules/postprocessing.py | 4 +- modules/processing.py | 4 +- modules/processing_scripts/refiner.py | 2 +- modules/processing_scripts/seed.py | 2 +- modules/shared_items.py | 4 +- modules/txt2img.py | 2 +- modules/ui.py | 4 +- modules/ui_common.py | 4 +- modules/ui_extra_networks.py | 2 +- modules/ui_extra_networks_user_metadata.py | 4 +- modules/ui_postprocessing.py | 2 +- 16 files changed, 527 insertions(+), 527 deletions(-) delete mode 100644 modules/infotext.py create mode 100644 modules/infotext_utils.py (limited to 'extensions-builtin') diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index 8aa901fd..4c10d9c7 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -1,7 +1,7 @@ import math import gradio as gr -from modules import scripts, shared, ui_components, ui_settings, infotext +from modules import scripts, shared, ui_components, ui_settings, infotext_utils from modules.ui_components import FormColumn @@ -25,7 +25,7 @@ class ExtraOptionsSection(scripts.Script): extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img") - mapping = {k: v for v, k in infotext.infotext_to_setting_name_mapping} + mapping = {k: v for v, 
k in infotext_utils.infotext_to_setting_name_mapping} with gr.Blocks() as interface: with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname): diff --git a/modules/api/api.py b/modules/api/api.py index 0e2807de..9d1292e9 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -17,7 +17,7 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext, sd_models +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -369,9 +369,9 @@ class Api: if not request.infotext: return {} - possible_fields = infotext.paste_fields[tabname]["fields"] + possible_fields = infotext_utils.paste_fields[tabname]["fields"] set_fields = request.model_dump(exclude_unset=True) if hasattr(request, "request") else request.dict(exclude_unset=True) # pydantic v1/v2 have differenrt names for this - params = infotext.parse_generation_parameters(request.infotext) + params = infotext_utils.parse_generation_parameters(request.infotext) def get_field_value(field, params): value = field.function(params) if field.function else params.get(field.label) @@ -408,7 +408,7 @@ class Api: if request.override_settings is None: request.override_settings = {} - overriden_settings = infotext.get_override_settings(params) + overriden_settings = infotext_utils.get_override_settings(params) for _, setting_name, value in overriden_settings: if setting_name not in request.override_settings: request.override_settings[setting_name] = value @@ -584,7 +584,7 @@ class Api: if geninfo is None: geninfo = "" - params = infotext.parse_generation_parameters(geninfo) + params = infotext_utils.parse_generation_parameters(geninfo) script_callbacks.infotext_pasted_callback(geninfo, params) return models.PNGInfoResponse(info=geninfo, items=items, parameters=params) diff --git a/modules/img2img.py b/modules/img2img.py index e7e8e251..04de8e62 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -7,7 +7,7 @@ from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageErr import gradio as gr from modules import images as imgutil -from modules.infotext import create_override_settings_dict, parse_generation_parameters +from modules.infotext_utils import create_override_settings_dict, parse_generation_parameters from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state from modules.sd_models import get_closet_checkpoint_match diff --git a/modules/infotext.py b/modules/infotext.py deleted file mode 100644 index 26e9b949..00000000 --- a/modules/infotext.py +++ /dev/null @@ -1,502 +0,0 @@ -from __future__ import annotations -import base64 -import io -import json -import os -import re -import sys - -import gradio as gr -from modules.paths import data_path -from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions -from PIL import Image - -sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name - 
-re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' -re_param = re.compile(re_param_code) -re_imagesize = re.compile(r"^(\d+)x(\d+)$") -re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") -type_of_gr_update = type(gr.update()) - - -class ParamBinding: - def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None): - self.paste_button = paste_button - self.tabname = tabname - self.source_text_component = source_text_component - self.source_image_component = source_image_component - self.source_tabname = source_tabname - self.override_settings_component = override_settings_component - self.paste_field_names = paste_field_names or [] - - -class PasteField(tuple): - def __new__(cls, component, target, *, api=None): - return super().__new__(cls, (component, target)) - - def __init__(self, component, target, *, api=None): - super().__init__() - - self.api = api - self.component = component - self.label = target if isinstance(target, str) else None - self.function = target if callable(target) else None - - -paste_fields: dict[str, dict] = {} -registered_param_bindings: list[ParamBinding] = [] - - -def reset(): - paste_fields.clear() - registered_param_bindings.clear() - - -def quote(text): - if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text): - return text - - return json.dumps(text, ensure_ascii=False) - - -def unquote(text): - if len(text) == 0 or text[0] != '"' or text[-1] != '"': - return text - - try: - return json.loads(text) - except Exception: - return text - - -def image_from_url_text(filedata): - if filedata is None: - return None - - if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False): - filedata = filedata[0] - - if type(filedata) == dict and filedata.get("is_file", False): - filename = filedata["name"] - is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename) - assert is_in_right_dir, 'trying to open image file outside of allowed directories' - - filename = filename.rsplit('?', 1)[0] - return Image.open(filename) - - if type(filedata) == list: - if len(filedata) == 0: - return None - - filedata = filedata[0] - - if filedata.startswith("data:image/png;base64,"): - filedata = filedata[len("data:image/png;base64,"):] - - filedata = base64.decodebytes(filedata.encode('utf-8')) - image = Image.open(io.BytesIO(filedata)) - return image - - -def add_paste_fields(tabname, init_img, fields, override_settings_component=None): - - if fields: - for i in range(len(fields)): - if not isinstance(fields[i], PasteField): - fields[i] = PasteField(*fields[i]) - - paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component} - - # backwards compatibility for existing extensions - import modules.ui - if tabname == 'txt2img': - modules.ui.txt2img_paste_fields = fields - elif tabname == 'img2img': - modules.ui.img2img_paste_fields = fields - - -def create_buttons(tabs_list): - buttons = {} - for tab in tabs_list: - buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab") - return buttons - - -def bind_buttons(buttons, send_image, send_generate_info): - """old function for backwards compatibility; do not use this, use register_paste_params_button""" - for tabname, button in buttons.items(): - source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None - 
source_tabname = send_generate_info if isinstance(send_generate_info, str) else None - - register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname)) - - -def register_paste_params_button(binding: ParamBinding): - registered_param_bindings.append(binding) - - -def connect_paste_params_buttons(): - for binding in registered_param_bindings: - destination_image_component = paste_fields[binding.tabname]["init_img"] - fields = paste_fields[binding.tabname]["fields"] - override_settings_component = binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"] - - destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None) - destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None) - - if binding.source_image_component and destination_image_component: - if isinstance(binding.source_image_component, gr.Gallery): - func = send_image_and_dimensions if destination_width_component else image_from_url_text - jsfunc = "extract_image_from_gallery" - else: - func = send_image_and_dimensions if destination_width_component else lambda x: x - jsfunc = None - - binding.paste_button.click( - fn=func, - _js=jsfunc, - inputs=[binding.source_image_component], - outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component], - show_progress=False, - ) - - if binding.source_text_component is not None and fields is not None: - connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname) - - if binding.source_tabname is not None and fields is not None: - paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names - binding.paste_button.click( - fn=lambda *x: x, - inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names], - outputs=[field for field, name in fields if name in paste_field_names], - show_progress=False, - ) - - binding.paste_button.click( - fn=None, - _js=f"switch_to_{binding.tabname}", - inputs=None, - outputs=None, - show_progress=False, - ) - - -def send_image_and_dimensions(x): - if isinstance(x, Image.Image): - img = x - else: - img = image_from_url_text(x) - - if shared.opts.send_size and isinstance(img, Image.Image): - w = img.width - h = img.height - else: - w = gr.update() - h = gr.update() - - return img, w, h - - -def restore_old_hires_fix_params(res): - """for infotexts that specify old First pass size parameter, convert it into - width, height, and hr scale""" - - firstpass_width = res.get('First pass size-1', None) - firstpass_height = res.get('First pass size-2', None) - - if shared.opts.use_old_hires_fix_width_height: - hires_width = int(res.get("Hires resize-1", 0)) - hires_height = int(res.get("Hires resize-2", 0)) - - if hires_width and hires_height: - res['Size-1'] = hires_width - res['Size-2'] = hires_height - return - - if firstpass_width is None or firstpass_height is None: - return - - firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height) - width = int(res.get("Size-1", 512)) - height = int(res.get("Size-2", 512)) - - if firstpass_width == 0 or firstpass_height == 0: - 
firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height) - - res['Size-1'] = firstpass_width - res['Size-2'] = firstpass_height - res['Hires resize-1'] = width - res['Hires resize-2'] = height - - -def parse_generation_parameters(x: str): - """parses generation parameters string, the one you see in text field under the picture in UI: -``` -girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate -Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing -Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b -``` - - returns a dict with field values - """ - - res = {} - - prompt = "" - negative_prompt = "" - - done_with_prompt = False - - *lines, lastline = x.strip().split("\n") - if len(re_param.findall(lastline)) < 3: - lines.append(lastline) - lastline = '' - - for line in lines: - line = line.strip() - if line.startswith("Negative prompt:"): - done_with_prompt = True - line = line[16:].strip() - if done_with_prompt: - negative_prompt += ("" if negative_prompt == "" else "\n") + line - else: - prompt += ("" if prompt == "" else "\n") + line - - if shared.opts.infotext_styles != "Ignore": - found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt) - - if shared.opts.infotext_styles == "Apply": - res["Styles array"] = found_styles - elif shared.opts.infotext_styles == "Apply if any" and found_styles: - res["Styles array"] = found_styles - - res["Prompt"] = prompt - res["Negative prompt"] = negative_prompt - - for k, v in re_param.findall(lastline): - try: - if v[0] == '"' and v[-1] == '"': - v = unquote(v) - - m = re_imagesize.match(v) - if m is not None: - res[f"{k}-1"] = m.group(1) - res[f"{k}-2"] = m.group(2) - else: - res[k] = v - except Exception: - print(f"Error parsing \"{k}: {v}\"") - - # Missing CLIP skip means it was set to 1 (the default) - if "Clip skip" not in res: - res["Clip skip"] = "1" - - hypernet = res.get("Hypernet", None) - if hypernet is not None: - res["Prompt"] += f"""""" - - if "Hires resize-1" not in res: - res["Hires resize-1"] = 0 - res["Hires resize-2"] = 0 - - if "Hires sampler" not in res: - res["Hires sampler"] = "Use same sampler" - - if "Hires checkpoint" not in res: - res["Hires checkpoint"] = "Use same checkpoint" - - if "Hires prompt" not in res: - res["Hires prompt"] = "" - - if "Hires negative prompt" not in res: - res["Hires negative prompt"] = "" - - restore_old_hires_fix_params(res) - - # Missing RNG means the default was set, which is GPU RNG - if "RNG" not in res: - res["RNG"] = "GPU" - - if "Schedule type" not in res: - res["Schedule type"] = "Automatic" - - if "Schedule max sigma" not in res: - res["Schedule max sigma"] = 0 - - if "Schedule min sigma" not in res: - res["Schedule min sigma"] = 0 - - if "Schedule rho" not in res: - res["Schedule rho"] = 0 - - if "VAE Encoder" not in res: - res["VAE Encoder"] = "Full" - - if "VAE Decoder" not in res: - res["VAE Decoder"] = "Full" - - if "FP8 weight" not in res: - res["FP8 weight"] = "Disable" - - if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable": - res["Cache FP16 weight for LoRA"] = False - - infotext_versions.backcompat(res) - - skip = set(shared.opts.infotext_skip_pasting) - 
res = {k: v for k, v in res.items() if k not in skip} - - return res - - -infotext_to_setting_name_mapping = [ - -] -"""Mapping of infotext labels to setting names. Only left for backwards compatibility - use OptionInfo(..., infotext='...') instead. -Example content: - -infotext_to_setting_name_mapping = [ - ('Conditional mask weight', 'inpainting_mask_weight'), - ('Model hash', 'sd_model_checkpoint'), - ('ENSD', 'eta_noise_seed_delta'), - ('Schedule type', 'k_sched_type'), -] -""" - - -def create_override_settings_dict(text_pairs): - """creates processing's override_settings parameters from gradio's multiselect - - Example input: - ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337'] - - Example output: - {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337} - """ - - res = {} - - params = {} - for pair in text_pairs: - k, v = pair.split(":", maxsplit=1) - - params[k] = v.strip() - - mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext] - for param_name, setting_name in mapping + infotext_to_setting_name_mapping: - value = params.get(param_name, None) - - if value is None: - continue - - res[setting_name] = shared.opts.cast_value(setting_name, value) - - return res - - -def get_override_settings(params, *, skip_fields=None): - """Returns a list of settings overrides from the infotext parameters dictionary. - - This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns - a list of tuples containing the parameter name, setting name, and new value cast to correct type. - - It checks for conditions before adding an override: - - ignores settings that match the current value - - ignores parameter keys present in skip_fields argument. 
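A rough usage sketch for the parser defined above (identical code reappears in the new modules/infotext_utils.py): it has to run inside the webui process where shared.opts is initialised, the sample string follows the docstring's format, and the commented values are illustrative.

from modules.infotext_utils import parse_generation_parameters

infotext = (
    "a photo of a cat, highly detailed\n"
    "Negative prompt: blurry, low quality\n"
    "Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b"
)

params = parse_generation_parameters(infotext)
# Prompt lines and the final key/value line are flattened into one dict;
# "Size: 512x512" is split into "Size-1"/"Size-2" and absent fields get
# defaults ("Clip skip" -> "1", "RNG" -> "GPU", and so on).
print(params["Prompt"])                     # a photo of a cat, highly detailed
print(params["Sampler"])                    # Euler a
print(params["Size-1"], params["Size-2"])   # 512 512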
- - Example input: - {"Clip skip": "2"} - - Example output: - [("Clip skip", "CLIP_stop_at_last_layers", 2)] - """ - - res = [] - - mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext] - for param_name, setting_name in mapping + infotext_to_setting_name_mapping: - if param_name in (skip_fields or {}): - continue - - v = params.get(param_name, None) - if v is None: - continue - - if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap: - continue - - v = shared.opts.cast_value(setting_name, v) - current_value = getattr(shared.opts, setting_name, None) - - if v == current_value: - continue - - res.append((param_name, setting_name, v)) - - return res - - -def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname): - def paste_func(prompt): - if not prompt and not shared.cmd_opts.hide_ui_dir_config: - filename = os.path.join(data_path, "params.txt") - if os.path.exists(filename): - with open(filename, "r", encoding="utf8") as file: - prompt = file.read() - - params = parse_generation_parameters(prompt) - script_callbacks.infotext_pasted_callback(prompt, params) - res = [] - - for output, key in paste_fields: - if callable(key): - v = key(params) - else: - v = params.get(key, None) - - if v is None: - res.append(gr.update()) - elif isinstance(v, type_of_gr_update): - res.append(v) - else: - try: - valtype = type(output.value) - - if valtype == bool and v == "False": - val = False - else: - val = valtype(v) - - res.append(gr.update(value=val)) - except Exception: - res.append(gr.update()) - - return res - - if override_settings_component is not None: - already_handled_fields = {key: 1 for _, key in paste_fields} - - def paste_settings(params): - vals = get_override_settings(params, skip_fields=already_handled_fields) - - vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals] - - return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs)) - - paste_fields = paste_fields + [(override_settings_component, paste_settings)] - - button.click( - fn=paste_func, - inputs=[input_comp], - outputs=[x[0] for x in paste_fields], - show_progress=False, - ) - button.click( - fn=None, - _js=f"recalculate_prompts_{tabname}", - inputs=[], - outputs=[], - show_progress=False, - ) - diff --git a/modules/infotext_utils.py b/modules/infotext_utils.py new file mode 100644 index 00000000..26e9b949 --- /dev/null +++ b/modules/infotext_utils.py @@ -0,0 +1,502 @@ +from __future__ import annotations +import base64 +import io +import json +import os +import re +import sys + +import gradio as gr +from modules.paths import data_path +from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions +from PIL import Image + +sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__] # alias for old name + +re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' +re_param = re.compile(re_param_code) +re_imagesize = re.compile(r"^(\d+)x(\d+)$") +re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") +type_of_gr_update = type(gr.update()) + + +class ParamBinding: + def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None): + self.paste_button = paste_button + self.tabname = tabname + self.source_text_component = source_text_component + self.source_image_component = source_image_component + 
self.source_tabname = source_tabname + self.override_settings_component = override_settings_component + self.paste_field_names = paste_field_names or [] + + +class PasteField(tuple): + def __new__(cls, component, target, *, api=None): + return super().__new__(cls, (component, target)) + + def __init__(self, component, target, *, api=None): + super().__init__() + + self.api = api + self.component = component + self.label = target if isinstance(target, str) else None + self.function = target if callable(target) else None + + +paste_fields: dict[str, dict] = {} +registered_param_bindings: list[ParamBinding] = [] + + +def reset(): + paste_fields.clear() + registered_param_bindings.clear() + + +def quote(text): + if ',' not in str(text) and '\n' not in str(text) and ':' not in str(text): + return text + + return json.dumps(text, ensure_ascii=False) + + +def unquote(text): + if len(text) == 0 or text[0] != '"' or text[-1] != '"': + return text + + try: + return json.loads(text) + except Exception: + return text + + +def image_from_url_text(filedata): + if filedata is None: + return None + + if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False): + filedata = filedata[0] + + if type(filedata) == dict and filedata.get("is_file", False): + filename = filedata["name"] + is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename) + assert is_in_right_dir, 'trying to open image file outside of allowed directories' + + filename = filename.rsplit('?', 1)[0] + return Image.open(filename) + + if type(filedata) == list: + if len(filedata) == 0: + return None + + filedata = filedata[0] + + if filedata.startswith("data:image/png;base64,"): + filedata = filedata[len("data:image/png;base64,"):] + + filedata = base64.decodebytes(filedata.encode('utf-8')) + image = Image.open(io.BytesIO(filedata)) + return image + + +def add_paste_fields(tabname, init_img, fields, override_settings_component=None): + + if fields: + for i in range(len(fields)): + if not isinstance(fields[i], PasteField): + fields[i] = PasteField(*fields[i]) + + paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component} + + # backwards compatibility for existing extensions + import modules.ui + if tabname == 'txt2img': + modules.ui.txt2img_paste_fields = fields + elif tabname == 'img2img': + modules.ui.img2img_paste_fields = fields + + +def create_buttons(tabs_list): + buttons = {} + for tab in tabs_list: + buttons[tab] = gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab") + return buttons + + +def bind_buttons(buttons, send_image, send_generate_info): + """old function for backwards compatibility; do not use this, use register_paste_params_button""" + for tabname, button in buttons.items(): + source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None + source_tabname = send_generate_info if isinstance(send_generate_info, str) else None + + register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname)) + + +def register_paste_params_button(binding: ParamBinding): + registered_param_bindings.append(binding) + + +def connect_paste_params_buttons(): + for binding in registered_param_bindings: + destination_image_component = paste_fields[binding.tabname]["init_img"] + fields = paste_fields[binding.tabname]["fields"] + override_settings_component = 
binding.override_settings_component or paste_fields[binding.tabname]["override_settings_component"] + + destination_width_component = next(iter([field for field, name in fields if name == "Size-1"] if fields else []), None) + destination_height_component = next(iter([field for field, name in fields if name == "Size-2"] if fields else []), None) + + if binding.source_image_component and destination_image_component: + if isinstance(binding.source_image_component, gr.Gallery): + func = send_image_and_dimensions if destination_width_component else image_from_url_text + jsfunc = "extract_image_from_gallery" + else: + func = send_image_and_dimensions if destination_width_component else lambda x: x + jsfunc = None + + binding.paste_button.click( + fn=func, + _js=jsfunc, + inputs=[binding.source_image_component], + outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component], + show_progress=False, + ) + + if binding.source_text_component is not None and fields is not None: + connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname) + + if binding.source_tabname is not None and fields is not None: + paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names + binding.paste_button.click( + fn=lambda *x: x, + inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names], + outputs=[field for field, name in fields if name in paste_field_names], + show_progress=False, + ) + + binding.paste_button.click( + fn=None, + _js=f"switch_to_{binding.tabname}", + inputs=None, + outputs=None, + show_progress=False, + ) + + +def send_image_and_dimensions(x): + if isinstance(x, Image.Image): + img = x + else: + img = image_from_url_text(x) + + if shared.opts.send_size and isinstance(img, Image.Image): + w = img.width + h = img.height + else: + w = gr.update() + h = gr.update() + + return img, w, h + + +def restore_old_hires_fix_params(res): + """for infotexts that specify old First pass size parameter, convert it into + width, height, and hr scale""" + + firstpass_width = res.get('First pass size-1', None) + firstpass_height = res.get('First pass size-2', None) + + if shared.opts.use_old_hires_fix_width_height: + hires_width = int(res.get("Hires resize-1", 0)) + hires_height = int(res.get("Hires resize-2", 0)) + + if hires_width and hires_height: + res['Size-1'] = hires_width + res['Size-2'] = hires_height + return + + if firstpass_width is None or firstpass_height is None: + return + + firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height) + width = int(res.get("Size-1", 512)) + height = int(res.get("Size-2", 512)) + + if firstpass_width == 0 or firstpass_height == 0: + firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height) + + res['Size-1'] = firstpass_width + res['Size-2'] = firstpass_height + res['Hires resize-1'] = width + res['Hires resize-2'] = height + + +def parse_generation_parameters(x: str): + """parses generation parameters string, the one you see in text field under the picture in UI: +``` +girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate +Negative prompt: ugly, fat, obese, chubby, (((deformed))), 
[blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing +Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b +``` + + returns a dict with field values + """ + + res = {} + + prompt = "" + negative_prompt = "" + + done_with_prompt = False + + *lines, lastline = x.strip().split("\n") + if len(re_param.findall(lastline)) < 3: + lines.append(lastline) + lastline = '' + + for line in lines: + line = line.strip() + if line.startswith("Negative prompt:"): + done_with_prompt = True + line = line[16:].strip() + if done_with_prompt: + negative_prompt += ("" if negative_prompt == "" else "\n") + line + else: + prompt += ("" if prompt == "" else "\n") + line + + if shared.opts.infotext_styles != "Ignore": + found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt) + + if shared.opts.infotext_styles == "Apply": + res["Styles array"] = found_styles + elif shared.opts.infotext_styles == "Apply if any" and found_styles: + res["Styles array"] = found_styles + + res["Prompt"] = prompt + res["Negative prompt"] = negative_prompt + + for k, v in re_param.findall(lastline): + try: + if v[0] == '"' and v[-1] == '"': + v = unquote(v) + + m = re_imagesize.match(v) + if m is not None: + res[f"{k}-1"] = m.group(1) + res[f"{k}-2"] = m.group(2) + else: + res[k] = v + except Exception: + print(f"Error parsing \"{k}: {v}\"") + + # Missing CLIP skip means it was set to 1 (the default) + if "Clip skip" not in res: + res["Clip skip"] = "1" + + hypernet = res.get("Hypernet", None) + if hypernet is not None: + res["Prompt"] += f"""""" + + if "Hires resize-1" not in res: + res["Hires resize-1"] = 0 + res["Hires resize-2"] = 0 + + if "Hires sampler" not in res: + res["Hires sampler"] = "Use same sampler" + + if "Hires checkpoint" not in res: + res["Hires checkpoint"] = "Use same checkpoint" + + if "Hires prompt" not in res: + res["Hires prompt"] = "" + + if "Hires negative prompt" not in res: + res["Hires negative prompt"] = "" + + restore_old_hires_fix_params(res) + + # Missing RNG means the default was set, which is GPU RNG + if "RNG" not in res: + res["RNG"] = "GPU" + + if "Schedule type" not in res: + res["Schedule type"] = "Automatic" + + if "Schedule max sigma" not in res: + res["Schedule max sigma"] = 0 + + if "Schedule min sigma" not in res: + res["Schedule min sigma"] = 0 + + if "Schedule rho" not in res: + res["Schedule rho"] = 0 + + if "VAE Encoder" not in res: + res["VAE Encoder"] = "Full" + + if "VAE Decoder" not in res: + res["VAE Decoder"] = "Full" + + if "FP8 weight" not in res: + res["FP8 weight"] = "Disable" + + if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable": + res["Cache FP16 weight for LoRA"] = False + + infotext_versions.backcompat(res) + + skip = set(shared.opts.infotext_skip_pasting) + res = {k: v for k, v in res.items() if k not in skip} + + return res + + +infotext_to_setting_name_mapping = [ + +] +"""Mapping of infotext labels to setting names. Only left for backwards compatibility - use OptionInfo(..., infotext='...') instead. 
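As the docstring above says, the mapping list is only a backwards-compatibility shim; new settings are expected to carry their infotext label directly. A sketch of such a registration, assuming the OptionInfo constructor accepts the infotext keyword exactly as the docstring suggests; the option name and label mirror the example content shown next and are illustrative here.

from modules import shared

# Declaring the infotext label on the option itself lets get_override_settings()
# and create_override_settings_dict() pick it up via shared.opts.data_labels,
# with no manual entry in infotext_to_setting_name_mapping.
shared.opts.add_option(
    "k_sched_type",
    shared.OptionInfo("Automatic", "Scheduler type", infotext="Schedule type"),
)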
+Example content: + +infotext_to_setting_name_mapping = [ + ('Conditional mask weight', 'inpainting_mask_weight'), + ('Model hash', 'sd_model_checkpoint'), + ('ENSD', 'eta_noise_seed_delta'), + ('Schedule type', 'k_sched_type'), +] +""" + + +def create_override_settings_dict(text_pairs): + """creates processing's override_settings parameters from gradio's multiselect + + Example input: + ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337'] + + Example output: + {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337} + """ + + res = {} + + params = {} + for pair in text_pairs: + k, v = pair.split(":", maxsplit=1) + + params[k] = v.strip() + + mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext] + for param_name, setting_name in mapping + infotext_to_setting_name_mapping: + value = params.get(param_name, None) + + if value is None: + continue + + res[setting_name] = shared.opts.cast_value(setting_name, value) + + return res + + +def get_override_settings(params, *, skip_fields=None): + """Returns a list of settings overrides from the infotext parameters dictionary. + + This function checks the `params` dictionary for any keys that correspond to settings in `shared.opts` and returns + a list of tuples containing the parameter name, setting name, and new value cast to correct type. + + It checks for conditions before adding an override: + - ignores settings that match the current value + - ignores parameter keys present in skip_fields argument. + + Example input: + {"Clip skip": "2"} + + Example output: + [("Clip skip", "CLIP_stop_at_last_layers", 2)] + """ + + res = [] + + mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext] + for param_name, setting_name in mapping + infotext_to_setting_name_mapping: + if param_name in (skip_fields or {}): + continue + + v = params.get(param_name, None) + if v is None: + continue + + if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap: + continue + + v = shared.opts.cast_value(setting_name, v) + current_value = getattr(shared.opts, setting_name, None) + + if v == current_value: + continue + + res.append((param_name, setting_name, v)) + + return res + + +def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname): + def paste_func(prompt): + if not prompt and not shared.cmd_opts.hide_ui_dir_config: + filename = os.path.join(data_path, "params.txt") + if os.path.exists(filename): + with open(filename, "r", encoding="utf8") as file: + prompt = file.read() + + params = parse_generation_parameters(prompt) + script_callbacks.infotext_pasted_callback(prompt, params) + res = [] + + for output, key in paste_fields: + if callable(key): + v = key(params) + else: + v = params.get(key, None) + + if v is None: + res.append(gr.update()) + elif isinstance(v, type_of_gr_update): + res.append(v) + else: + try: + valtype = type(output.value) + + if valtype == bool and v == "False": + val = False + else: + val = valtype(v) + + res.append(gr.update(value=val)) + except Exception: + res.append(gr.update()) + + return res + + if override_settings_component is not None: + already_handled_fields = {key: 1 for _, key in paste_fields} + + def paste_settings(params): + vals = get_override_settings(params, skip_fields=already_handled_fields) + + vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals] + + return gr.Dropdown.update(value=vals_pairs, 
choices=vals_pairs, visible=bool(vals_pairs)) + + paste_fields = paste_fields + [(override_settings_component, paste_settings)] + + button.click( + fn=paste_func, + inputs=[input_comp], + outputs=[x[0] for x in paste_fields], + show_progress=False, + ) + button.click( + fn=None, + _js=f"recalculate_prompts_{tabname}", + inputs=[], + outputs=[], + show_progress=False, + ) + diff --git a/modules/postprocessing.py b/modules/postprocessing.py index facea899..7850328f 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -2,7 +2,7 @@ import os from PIL import Image -from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common +from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, infotext_utils from modules.shared import opts @@ -86,7 +86,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, basename = '' forced_filename = None - infotext = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in pp.info.items() if v is not None]) + infotext = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in pp.info.items() if v is not None]) if opts.enable_pnginfo: pp.image.info = existing_pnginfo diff --git a/modules/processing.py b/modules/processing.py index f55b85ed..213a2879 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -16,7 +16,7 @@ from skimage import exposure from typing import Any import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng from modules.rng import slerp # noqa: F401 from modules.sd_hijack import model_hijack from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes @@ -746,7 +746,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "User": p.user if opts.add_user_name_to_info else None, } - generation_params_text = ", ".join([k if k == v else f'{k}: {infotext.quote(v)}' for k, v in generation_params.items() if v is not None]) + generation_params_text = ", ".join([k if k == v else f'{k}: {infotext_utils.quote(v)}' for k, v in generation_params.items() if v is not None]) prompt_text = p.main_prompt if use_main_prompt else all_prompts[index] negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else "" diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py index e9941413..ba33d8a4 100644 --- a/modules/processing_scripts/refiner.py +++ b/modules/processing_scripts/refiner.py @@ -1,7 +1,7 @@ import gradio as gr from modules import scripts, sd_models -from modules.infotext import PasteField +from modules.infotext_utils import PasteField from modules.ui_common import create_refresh_button from modules.ui_components import InputAccordion diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py index 60293278..2d3cbb97 100644 --- a/modules/processing_scripts/seed.py +++ b/modules/processing_scripts/seed.py @@ -3,7 +3,7 @@ import json import gradio as gr from modules import scripts, ui, errors -from modules.infotext import PasteField +from modules.infotext_utils import PasteField from 
modules.shared import cmd_opts from modules.ui_components import ToolButton diff --git a/modules/shared_items.py b/modules/shared_items.py index e1392472..13fb2814 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -67,14 +67,14 @@ def reload_hypernetworks(): def get_infotext_names(): - from modules import infotext, shared + from modules import infotext_utils, shared res = {} for info in shared.opts.data_labels.values(): if info.infotext: res[info.infotext] = 1 - for tab_data in infotext.paste_fields.values(): + for tab_data in infotext_utils.paste_fields.values(): for _, name in tab_data.get("fields") or []: if isinstance(name, str): res[name] = 1 diff --git a/modules/txt2img.py b/modules/txt2img.py index 3a481915..49660e89 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -2,7 +2,7 @@ from contextlib import closing import modules.scripts from modules import processing -from modules.infotext import create_override_settings_dict +from modules.infotext_utils import create_override_settings_dict from modules.shared import opts import modules.shared as shared from modules.ui import plaintext_to_html diff --git a/modules/ui.py b/modules/ui.py index 378529c7..52b15646 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -21,14 +21,14 @@ from modules.ui_gradio_extensions import reload_javascript from modules.shared import opts, cmd_opts -import modules.infotext as parameters_copypaste +import modules.infotext_utils as parameters_copypaste import modules.hypernetworks.ui as hypernetworks_ui import modules.textual_inversion.ui as textual_inversion_ui import modules.textual_inversion.textual_inversion as textual_inversion import modules.shared as shared from modules import prompt_parser from modules.sd_hijack import model_hijack -from modules.infotext import image_from_url_text, PasteField +from modules.infotext_utils import image_from_url_text, PasteField create_setting_component = ui_settings.create_setting_component diff --git a/modules/ui_common.py b/modules/ui_common.py index fd32676f..f48ad426 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -8,10 +8,10 @@ import gradio as gr import subprocess as sp from modules import call_queue, shared -from modules.infotext import image_from_url_text +from modules.infotext_utils import image_from_url_text import modules.images from modules.ui_components import ToolButton -import modules.infotext as parameters_copypaste +import modules.infotext_utils as parameters_copypaste folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 790af135..beea1316 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -10,7 +10,7 @@ import json import html from fastapi.exceptions import HTTPException -from modules.infotext import image_from_url_text +from modules.infotext_utils import image_from_url_text from modules.ui_components import ToolButton extra_pages = [] diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 87aeb6f3..989a649b 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -5,7 +5,7 @@ import os.path import gradio as gr -from modules import infotext, images, sysinfo, errors, ui_extra_networks +from modules import infotext_utils, images, sysinfo, errors, ui_extra_networks class UserMetadataEditor: @@ -181,7 +181,7 @@ class UserMetadataEditor: index = len(gallery) - 1 if index >= len(gallery) else 
index img_info = gallery[index if index >= 0 else 0] - image = infotext.image_from_url_text(img_info) + image = infotext_utils.image_from_url_text(img_info) geninfo, items = images.read_info_from_image(image) images.save_image_with_geninfo(image, geninfo, item["local_preview"]) diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index b74a1532..1edb68c5 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -1,6 +1,6 @@ import gradio as gr from modules import scripts, shared, ui_common, postprocessing, call_queue, ui_toprow -import modules.infotext as parameters_copypaste +import modules.infotext_utils as parameters_copypaste def create_ui(): -- cgit v1.2.3 From cf14a6a7aaf8ccb40552990785d5c9e400d93610 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 16:11:18 +0200 Subject: Refactor upscale_2 helper out of ScuNET/SwinIR; make sure devices are right --- extensions-builtin/ScuNET/scripts/scunet_model.py | 48 +++--------- extensions-builtin/SwinIR/scripts/swinir_model.py | 62 ++-------------- modules/upscaler_utils.py | 89 ++++++++++++++++++----- 3 files changed, 87 insertions(+), 112 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index f799cb76..fe5e5a19 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -1,13 +1,9 @@ import sys import PIL.Image -import numpy as np -import torch import modules.upscaler -from modules import devices, modelloader, script_callbacks, errors -from modules.shared import opts -from modules.upscaler_utils import tiled_upscale_2 +from modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils class UpscalerScuNET(modules.upscaler.Upscaler): @@ -40,46 +36,23 @@ class UpscalerScuNET(modules.upscaler.Upscaler): self.scalers = scalers def do_upscale(self, img: PIL.Image.Image, selected_file): - devices.torch_gc() - try: model = self.load_model(selected_file) except Exception as e: print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr) return img - device = devices.get_device_for('scunet') - tile = opts.SCUNET_tile - h, w = img.height, img.width - np_img = np.array(img) - np_img = np_img[:, :, ::-1] # RGB to BGR - np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW - torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore - - if tile > h or tile > w: - _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device) - _img[:, :, :h, :w] = torch_img # pad image - torch_img = _img - - with torch.no_grad(): - torch_output = tiled_upscale_2( - torch_img, - model, - tile_size=opts.SCUNET_tile, - tile_overlap=opts.SCUNET_tile_overlap, - scale=1, - device=devices.get_device_for('scunet'), - desc="ScuNET tiles", - ).squeeze(0) - torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any - np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() - del torch_img, torch_output + img = upscaler_utils.upscale_2( + img, + model, + tile_size=shared.opts.SCUNET_tile, + tile_overlap=shared.opts.SCUNET_tile_overlap, + scale=1, # ScuNET is a denoising model, not an upscaler + desc='ScuNET', + ) devices.torch_gc() - - output = np_output.transpose((1, 2, 0)) # CHW to HWC - output = output[:, :, ::-1] # BGR to RGB - return PIL.Image.fromarray((output * 255).astype(np.uint8)) + return img def 
load_model(self, path: str): device = devices.get_device_for('scunet') @@ -93,7 +66,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler): def on_ui_settings(): import gradio as gr - from modules import shared shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling")) shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam")) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 8a555c79..bc427fea 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,14 +1,10 @@ import logging import sys -import numpy as np -import torch from PIL import Image -from modules import modelloader, devices, script_callbacks, shared -from modules.shared import opts +from modules import devices, modelloader, script_callbacks, shared, upscaler_utils from modules.upscaler import Upscaler, UpscalerData -from modules.upscaler_utils import tiled_upscale_2 SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" @@ -36,9 +32,7 @@ class UpscalerSwinIR(Upscaler): self.scalers = scalers def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image: - current_config = (model_file, opts.SWIN_tile) - - device = self._get_device() + current_config = (model_file, shared.opts.SWIN_tile) if self._cached_model_config == current_config: model = self._cached_model @@ -51,12 +45,13 @@ class UpscalerSwinIR(Upscaler): self._cached_model = model self._cached_model_config = current_config - img = upscale( + img = upscaler_utils.upscale_2( img, model, - tile=opts.SWIN_tile, - tile_overlap=opts.SWIN_tile_overlap, - device=device, + tile_size=shared.opts.SWIN_tile, + tile_overlap=shared.opts.SWIN_tile_overlap, + scale=4, # TODO: This was hard-coded before too... 
+ desc="SwinIR", ) devices.torch_gc() return img @@ -77,7 +72,7 @@ class UpscalerSwinIR(Upscaler): dtype=devices.dtype, expected_architecture="SwinIR", ) - if getattr(opts, 'SWIN_torch_compile', False): + if getattr(shared.opts, 'SWIN_torch_compile', False): try: model_descriptor.model.compile() except Exception: @@ -88,47 +83,6 @@ class UpscalerSwinIR(Upscaler): return devices.get_device_for('swinir') -def upscale( - img, - model, - *, - tile: int, - tile_overlap: int, - window_size=8, - scale=4, - device, -): - - img = np.array(img) - img = img[:, :, ::-1] - img = np.moveaxis(img, 2, 0) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device, dtype=devices.dtype) - with torch.no_grad(), devices.autocast(): - _, _, h_old, w_old = img.size() - h_pad = (h_old // window_size + 1) * window_size - h_old - w_pad = (w_old // window_size + 1) * window_size - w_old - img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] - img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] - output = tiled_upscale_2( - img, - model, - tile_size=tile, - tile_overlap=tile_overlap, - scale=scale, - device=device, - desc="SwinIR tiles", - ) - output = output[..., : h_old * scale, : w_old * scale] - output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() - if output.ndim == 3: - output = np.transpose( - output[[2, 1, 0], :, :], (1, 2, 0) - ) # CHW-RGB to HCW-BGR - output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 - return Image.fromarray(output, "RGB") - - def on_ui_settings(): import gradio as gr diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 9379f512..e4c63f09 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -11,23 +11,40 @@ from modules import images, shared, torch_utils logger = logging.getLogger(__name__) -def upscale_without_tiling(model, img: Image.Image): - img = np.array(img) - img = img[:, :, ::-1] - img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 - img = torch.from_numpy(img).float() - +def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor: + img = np.array(img.convert("RGB")) + img = img[:, :, ::-1] # flip RGB to BGR + img = np.transpose(img, (2, 0, 1)) # HWC to CHW + img = np.ascontiguousarray(img) / 255 # Rescale to [0, 1] + return torch.from_numpy(img) + + +def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image: + if tensor.ndim == 4: + # If we're given a tensor with a batch dimension, squeeze it out + # (but only if it's a batch of size 1). + if tensor.shape[0] != 1: + raise ValueError(f"{tensor.shape} does not describe a BCHW tensor") + tensor = tensor.squeeze(0) + assert tensor.ndim == 3, f"{tensor.shape} does not describe a CHW tensor" + # TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom? + arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp + arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale + arr = arr.astype(np.uint8) + arr = arr[:, :, ::-1] # flip BGR to RGB + return Image.fromarray(arr, "RGB") + + +def upscale_pil_patch(model, img: Image.Image) -> Image.Image: + """ + Upscale a given PIL image using the given model. + """ param = torch_utils.get_param(model) - img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): - output = model(img) - - output = output.squeeze().float().cpu().clamp_(0, 1).numpy() - output = 255. 
* np.moveaxis(output, 0, 2) - output = output.astype(np.uint8) - output = output[:, :, ::-1] - return Image.fromarray(output, 'RGB') + tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension + tensor = tensor.to(device=param.device, dtype=param.dtype) + return torch_bgr_to_pil_image(model(tensor)) def upscale_with_model( @@ -40,7 +57,7 @@ def upscale_with_model( ) -> Image.Image: if tile_size <= 0: logger.debug("Upscaling %s without tiling", img) - output = upscale_without_tiling(model, img) + output = upscale_pil_patch(model, img) logger.debug("=> %s", output) return output @@ -52,7 +69,7 @@ def upscale_with_model( newrow = [] for x, w, tile in row: logger.debug("Tile (%d, %d) %s...", x, y, tile) - output = upscale_without_tiling(model, tile) + output = upscale_pil_patch(model, tile) scale_factor = output.width // tile.width logger.debug("=> %s (scale factor %s)", output, scale_factor) newrow.append([x * scale_factor, w * scale_factor, output]) @@ -71,19 +88,22 @@ def upscale_with_model( def tiled_upscale_2( - img, + img: torch.Tensor, model, *, tile_size: int, tile_overlap: int, scale: int, - device, desc="Tiled upscale", ): # Alternative implementation of `upscale_with_model` originally used by # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in # Pillow space without weighting. + + # Grab the device the model is on, and use it. + device = torch_utils.get_param(model).device + b, c, h, w = img.size() tile_size = min(tile_size, h, w) @@ -100,7 +120,8 @@ def tiled_upscale_2( h * scale, w * scale, device=device, - ).type_as(img) + dtype=img.dtype, + ) weights = torch.zeros_like(result) logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar: @@ -112,11 +133,13 @@ def tiled_upscale_2( if shared.state.interrupted or shared.state.skipped: break + # Only move this patch to the device if it's not already there. in_patch = img[ ..., h_idx : h_idx + tile_size, w_idx : w_idx + tile_size, - ] + ].to(device=device) + out_patch = model(in_patch) result[ @@ -138,3 +161,29 @@ def tiled_upscale_2( output = result.div_(weights) return output + + +def upscale_2( + img: Image.Image, + model, + *, + tile_size: int, + tile_overlap: int, + scale: int, + desc: str, +): + """ + Convenience wrapper around `tiled_upscale_2` that handles PIL images. 
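Both ScuNET and SwinIR now route through this wrapper, as the hunks above show. A rough sketch of calling it directly, mirroring the SwinIR call site; the checkpoint path, device string, and tile numbers are illustrative, not values taken from this patch.

import PIL.Image
from modules import modelloader, upscaler_utils

model = modelloader.load_spandrel_model(
    "models/SwinIR/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth",  # illustrative path
    device="cuda",
    expected_architecture="SwinIR",
)

img = PIL.Image.open("input.png")
upscaled = upscaler_utils.upscale_2(
    img,
    model,
    tile_size=192,       # illustrative; the UI reads this from shared.opts.SWIN_tile
    tile_overlap=8,
    scale=model.scale,   # spandrel descriptors expose their native scale factor
    desc="SwinIR",
)
upscaled.save("output.png")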
+ """ + tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension + + with torch.no_grad(): + output = tiled_upscale_2( + tensor, + model, + tile_size=tile_size, + tile_overlap=tile_overlap, + scale=scale, + desc=desc, + ) + return torch_bgr_to_pil_image(output) -- cgit v1.2.3 From dfdc51246c678b585e1bdfdb7d2f202b0ca0e362 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 3 Jan 2024 22:38:13 +0200 Subject: SwinIR: use prefer_half --- extensions-builtin/SwinIR/scripts/swinir_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index bc427fea..6a8e21b0 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,6 +1,7 @@ import logging import sys +import torch from PIL import Image from modules import devices, modelloader, script_callbacks, shared, upscaler_utils @@ -69,7 +70,7 @@ class UpscalerSwinIR(Upscaler): model_descriptor = modelloader.load_spandrel_model( filename, device=self._get_device(), - dtype=devices.dtype, + prefer_half=(devices.dtype == torch.float16), expected_architecture="SwinIR", ) if getattr(shared.opts, 'SWIN_torch_compile', False): -- cgit v1.2.3 From 3d31d5c27beb433fa37b30f135ec06a278a87630 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 3 Jan 2024 22:38:49 +0200 Subject: SwinIR: pass model.scale --- extensions-builtin/SwinIR/scripts/swinir_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 6a8e21b0..16bf9b79 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -51,7 +51,7 @@ class UpscalerSwinIR(Upscaler): model, tile_size=shared.opts.SWIN_tile, tile_overlap=shared.opts.SWIN_tile_overlap, - scale=4, # TODO: This was hard-coded before too... + scale=model.scale, desc="SwinIR", ) devices.torch_gc() -- cgit v1.2.3 From f8f38c7c28e48f9f79225c969e3e82b1adcfb910 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 5 Jan 2024 16:31:48 +0800 Subject: Fix dtype casting for OFT module --- extensions-builtin/Lora/network_oft.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index fa647020..342fcd0d 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -56,7 +56,7 @@ class NetworkModuleOFT(network.NetworkModule): self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) def calc_updown(self, orig_weight): - oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + oft_blocks = self.oft_blocks.to(orig_weight.device) eye = torch.eye(self.block_size, device=self.oft_blocks.device) if self.is_kohya: @@ -66,7 +66,7 @@ class NetworkModuleOFT(network.NetworkModule): block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) - R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + R = oft_blocks.to(orig_weight.device) # This errors out for MultiheadAttention, might need to be handled up-stream merged_weight = rearrange(orig_weight, '(k n) ... 
-> k n ...', k=self.num_blocks, n=self.block_size) @@ -77,6 +77,6 @@ class NetworkModuleOFT(network.NetworkModule): ) merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') - updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) -- cgit v1.2.3 From 18ca987c92f52690daec43a6c67363c341bb6008 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 5 Jan 2024 16:32:19 +0800 Subject: Add general forward method for all modules. --- extensions-builtin/Lora/network.py | 34 +++++++++++++++++++++++++++++++++- extensions-builtin/Lora/networks.py | 12 ++++++------ 2 files changed, 39 insertions(+), 7 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index a62e5eff..f9b571b5 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -3,6 +3,10 @@ import os from collections import namedtuple import enum +import torch +import torch.nn as nn +import torch.nn.functional as F + from modules import sd_models, cache, errors, hashes, shared NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module']) @@ -115,6 +119,29 @@ class NetworkModule: if hasattr(self.sd_module, 'weight'): self.shape = self.sd_module.weight.shape + self.ops = None + self.extra_kwargs = {} + if isinstance(self.sd_module, nn.Conv2d): + self.ops = F.conv2d + self.extra_kwargs = { + 'stride': self.sd_module.stride, + 'padding': self.sd_module.padding + } + elif isinstance(self.sd_module, nn.Linear): + self.ops = F.linear + elif isinstance(self.sd_module, nn.LayerNorm): + self.ops = F.layer_norm + self.extra_kwargs = { + 'normalized_shape': self.sd_module.normalized_shape, + 'eps': self.sd_module.eps + } + elif isinstance(self.sd_module, nn.GroupNorm): + self.ops = F.group_norm + self.extra_kwargs = { + 'num_groups': self.sd_module.num_groups, + 'eps': self.sd_module.eps + } + self.dim = None self.bias = weights.w.get("bias") self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None @@ -155,5 +182,10 @@ class NetworkModule: raise NotImplementedError() def forward(self, x, y): - raise NotImplementedError() + """A general forward implementation for all modules""" + if self.ops is None: + raise NotImplementedError() + else: + updown, ex_bias = self.calc_updown(self.sd_module.weight) + return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 72ebd624..32e10b62 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -458,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn self.network_current_names = wanted_names -def network_forward(module, input, original_forward): +def network_forward(org_module, input, original_forward): """ Old way of applying Lora by executing operations during layer's forward. Stacking many loras this way results in big performance degradation. 
""" if len(loaded_networks) == 0: - return original_forward(module, input) + return original_forward(org_module, input) input = devices.cond_cast_unet(input) - network_restore_weights_from_backup(module) - network_reset_cached_weight(module) + network_restore_weights_from_backup(org_module) + network_reset_cached_weight(org_module) - y = original_forward(module, input) + y = original_forward(org_module, input) - network_layer_name = getattr(module, 'network_layer_name', None) + network_layer_name = getattr(org_module, 'network_layer_name', None) for lora in loaded_networks: module = lora.modules.get(network_layer_name, None) if module is None: -- cgit v1.2.3 From 44744d6005da5e424267698ee3279caa597dfebc Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 5 Jan 2024 16:38:05 +0800 Subject: linting --- extensions-builtin/Lora/network.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index f9b571b5..b8fd9194 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -3,7 +3,6 @@ import os from collections import namedtuple import enum -import torch import torch.nn as nn import torch.nn.functional as F @@ -124,7 +123,7 @@ class NetworkModule: if isinstance(self.sd_module, nn.Conv2d): self.ops = F.conv2d self.extra_kwargs = { - 'stride': self.sd_module.stride, + 'stride': self.sd_module.stride, 'padding': self.sd_module.padding } elif isinstance(self.sd_module, nn.Linear): -- cgit v1.2.3