diff options
Diffstat (limited to 'extensions-builtin')
20 files changed, 898 insertions, 2413 deletions
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 6021fd8d..b8fd9194 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -3,6 +3,9 @@ import os from collections import namedtuple
import enum
+import torch.nn as nn
+import torch.nn.functional as F
+
from modules import sd_models, cache, errors, hashes, shared
NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
@@ -115,6 +118,29 @@ class NetworkModule: if hasattr(self.sd_module, 'weight'):
self.shape = self.sd_module.weight.shape
+ self.ops = None
+ self.extra_kwargs = {}
+ if isinstance(self.sd_module, nn.Conv2d):
+ self.ops = F.conv2d
+ self.extra_kwargs = {
+ 'stride': self.sd_module.stride,
+ 'padding': self.sd_module.padding
+ }
+ elif isinstance(self.sd_module, nn.Linear):
+ self.ops = F.linear
+ elif isinstance(self.sd_module, nn.LayerNorm):
+ self.ops = F.layer_norm
+ self.extra_kwargs = {
+ 'normalized_shape': self.sd_module.normalized_shape,
+ 'eps': self.sd_module.eps
+ }
+ elif isinstance(self.sd_module, nn.GroupNorm):
+ self.ops = F.group_norm
+ self.extra_kwargs = {
+ 'num_groups': self.sd_module.num_groups,
+ 'eps': self.sd_module.eps
+ }
+
self.dim = None
self.bias = weights.w.get("bias")
self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
@@ -137,7 +163,7 @@ class NetworkModule: def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
if self.bias is not None:
updown = updown.reshape(self.bias.shape)
- updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown += self.bias.to(orig_weight.device, dtype=updown.dtype)
updown = updown.reshape(output_shape)
if len(output_shape) == 4:
@@ -155,5 +181,10 @@ class NetworkModule: raise NotImplementedError()
def forward(self, x, y):
- raise NotImplementedError()
+ """A general forward implementation for all modules"""
+ if self.ops is None:
+ raise NotImplementedError()
+ else:
+ updown, ex_bias = self.calc_updown(self.sd_module.weight)
+ return y + self.ops(x, weight=updown, bias=ex_bias, **self.extra_kwargs)
diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py index bf6930e9..f221c95f 100644 --- a/extensions-builtin/Lora/network_full.py +++ b/extensions-builtin/Lora/network_full.py @@ -18,9 +18,9 @@ class NetworkModuleFull(network.NetworkModule): def calc_updown(self, orig_weight):
output_shape = self.weight.shape
- updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ updown = self.weight.to(orig_weight.device)
if self.ex_bias is not None:
- ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype)
+ ex_bias = self.ex_bias.to(orig_weight.device)
else:
ex_bias = None
diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py index 492d4870..efe5c681 100644 --- a/extensions-builtin/Lora/network_glora.py +++ b/extensions-builtin/Lora/network_glora.py @@ -22,12 +22,12 @@ class NetworkModuleGLora(network.NetworkModule): self.w2b = weights.w["b2.weight"] def calc_updown(self, orig_weight): - w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) - w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) - w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) - w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w1a = self.w1a.to(orig_weight.device) + w1b = self.w1b.to(orig_weight.device) + w2a = self.w2a.to(orig_weight.device) + w2b = self.w2b.to(orig_weight.device) output_shape = [w1a.size(0), w1b.size(1)] - updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a)) + updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a)) return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py index 5fcb0695..d95a0fd1 100644 --- a/extensions-builtin/Lora/network_hada.py +++ b/extensions-builtin/Lora/network_hada.py @@ -27,16 +27,16 @@ class NetworkModuleHada(network.NetworkModule): self.t2 = weights.w.get("hada_t2")
def calc_updown(self, orig_weight):
- w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
- w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1a = self.w1a.to(orig_weight.device)
+ w1b = self.w1b.to(orig_weight.device)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
output_shape = [w1a.size(0), w1b.size(1)]
if self.t1 is not None:
output_shape = [w1a.size(1), w1b.size(1)]
- t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
+ t1 = self.t1.to(orig_weight.device)
updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
output_shape += t1.shape[2:]
else:
@@ -45,7 +45,7 @@ class NetworkModuleHada(network.NetworkModule): updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)
if self.t2 is not None:
- t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
+ t2 = self.t2.to(orig_weight.device)
updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
else:
updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)
diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py index 7edc4249..96faeaf3 100644 --- a/extensions-builtin/Lora/network_ia3.py +++ b/extensions-builtin/Lora/network_ia3.py @@ -17,7 +17,7 @@ class NetworkModuleIa3(network.NetworkModule): self.on_input = weights.w["on_input"].item()
def calc_updown(self, orig_weight):
- w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)
+ w = self.w.to(orig_weight.device)
output_shape = [w.size(0), orig_weight.size(1)]
if self.on_input:
diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py index 340acdab..fcdaeafd 100644 --- a/extensions-builtin/Lora/network_lokr.py +++ b/extensions-builtin/Lora/network_lokr.py @@ -37,22 +37,22 @@ class NetworkModuleLokr(network.NetworkModule): def calc_updown(self, orig_weight):
if self.w1 is not None:
- w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1 = self.w1.to(orig_weight.device)
else:
- w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
- w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w1a = self.w1a.to(orig_weight.device)
+ w1b = self.w1b.to(orig_weight.device)
w1 = w1a @ w1b
if self.w2 is not None:
- w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2 = self.w2.to(orig_weight.device)
elif self.t2 is None:
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
w2 = w2a @ w2b
else:
- t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
- w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
- w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+ t2 = self.t2.to(orig_weight.device)
+ w2a = self.w2a.to(orig_weight.device)
+ w2b = self.w2b.to(orig_weight.device)
w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py index 26c0a72c..4cc40295 100644 --- a/extensions-builtin/Lora/network_lora.py +++ b/extensions-builtin/Lora/network_lora.py @@ -61,13 +61,13 @@ class NetworkModuleLora(network.NetworkModule): return module
def calc_updown(self, orig_weight):
- up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
- down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ up = self.up_model.weight.to(orig_weight.device)
+ down = self.down_model.weight.to(orig_weight.device)
output_shape = [up.size(0), down.size(1)]
if self.mid_model is not None:
# cp-decomposition
- mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
+ mid = self.mid_model.weight.to(orig_weight.device)
updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
output_shape += mid.shape[2:]
else:
diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py index ce450158..d25afcbb 100644 --- a/extensions-builtin/Lora/network_norm.py +++ b/extensions-builtin/Lora/network_norm.py @@ -18,10 +18,10 @@ class NetworkModuleNorm(network.NetworkModule): def calc_updown(self, orig_weight): output_shape = self.w_norm.shape - updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype) + updown = self.w_norm.to(orig_weight.device) if self.b_norm is not None: - ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype) + ex_bias = self.b_norm.to(orig_weight.device) else: ex_bias = None diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index fa647020..342fcd0d 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -56,7 +56,7 @@ class NetworkModuleOFT(network.NetworkModule): self.block_size, self.num_blocks = factorization(self.out_dim, self.dim) def calc_updown(self, orig_weight): - oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + oft_blocks = self.oft_blocks.to(orig_weight.device) eye = torch.eye(self.block_size, device=self.oft_blocks.device) if self.is_kohya: @@ -66,7 +66,7 @@ class NetworkModuleOFT(network.NetworkModule): block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse()) - R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + R = oft_blocks.to(orig_weight.device) # This errors out for MultiheadAttention, might need to be handled up-stream merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size) @@ -77,6 +77,6 @@ class NetworkModuleOFT(network.NetworkModule): ) merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...') - updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight + updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype) output_shape = orig_weight.shape return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 629bf853..32e10b62 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -1,3 +1,4 @@ +import gradio as gr
import logging
import os
import re
@@ -314,7 +315,12 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No emb_db.skipped_embeddings[name] = embedding
if failed_to_load_networks:
- sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
+ lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}'
+ sd_hijack.model_hijack.comments.append(lora_not_found_message)
+ if shared.opts.lora_not_found_warning_console:
+ print(f'\n{lora_not_found_message}\n')
+ if shared.opts.lora_not_found_gradio_warning:
+ gr.Warning(lora_not_found_message)
purge_networks_from_memory()
@@ -389,18 +395,26 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn if module is not None and hasattr(self, 'weight'):
try:
with torch.no_grad():
- updown, ex_bias = module.calc_updown(self.weight)
+ if getattr(self, 'fp16_weight', None) is None:
+ weight = self.weight
+ bias = self.bias
+ else:
+ weight = self.fp16_weight.clone().to(self.weight.device)
+ bias = getattr(self, 'fp16_bias', None)
+ if bias is not None:
+ bias = bias.clone().to(self.bias.device)
+ updown, ex_bias = module.calc_updown(weight)
- if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
+ if len(weight.shape) == 4 and weight.shape[1] == 9:
# inpainting model. zero pad updown to make channel[1] 4 to 9
updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
- self.weight += updown
+ self.weight.copy_((weight.to(dtype=updown.dtype) + updown).to(dtype=self.weight.dtype))
if ex_bias is not None and hasattr(self, 'bias'):
if self.bias is None:
- self.bias = torch.nn.Parameter(ex_bias)
+ self.bias = torch.nn.Parameter(ex_bias).to(self.weight.dtype)
else:
- self.bias += ex_bias
+ self.bias.copy_((bias + ex_bias).to(dtype=self.bias.dtype))
except RuntimeError as e:
logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1
@@ -444,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn self.network_current_names = wanted_names
-def network_forward(module, input, original_forward):
+def network_forward(org_module, input, original_forward):
"""
Old way of applying Lora by executing operations during layer's forward.
Stacking many loras this way results in big performance degradation.
"""
if len(loaded_networks) == 0:
- return original_forward(module, input)
+ return original_forward(org_module, input)
input = devices.cond_cast_unet(input)
- network_restore_weights_from_backup(module)
- network_reset_cached_weight(module)
+ network_restore_weights_from_backup(org_module)
+ network_reset_cached_weight(org_module)
- y = original_forward(module, input)
+ y = original_forward(org_module, input)
- network_layer_name = getattr(module, 'network_layer_name', None)
+ network_layer_name = getattr(org_module, 'network_layer_name', None)
for lora in loaded_networks:
module = lora.modules.get(network_layer_name, None)
if module is None:
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index ef23968c..1518f7e5 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -39,6 +39,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
+ "lora_not_found_warning_console": shared.OptionInfo(False, "Lora not found warning in console"),
+ "lora_not_found_gradio_warning": shared.OptionInfo(False, "Lora not found warning popup in webui"),
}))
diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index c7011909..3160aecf 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -54,12 +54,13 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.slider_preferred_weight = None
self.edit_notes = None
- def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes):
+ def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, negative_text, notes):
user_metadata = self.get_user_metadata(name)
user_metadata["description"] = desc
user_metadata["sd version"] = sd_version
user_metadata["activation text"] = activation_text
user_metadata["preferred weight"] = preferred_weight
+ user_metadata["negative text"] = negative_text
user_metadata["notes"] = notes
self.write_user_metadata(name, user_metadata)
@@ -127,6 +128,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
user_metadata.get('activation text', ''),
float(user_metadata.get('preferred weight', 0.0)),
+ user_metadata.get('negative text', ''),
gr.update(visible=True if tags else False),
gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
]
@@ -162,7 +164,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.taginfo = gr.HighlightedText(label="Training dataset tags")
self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
-
+ self.edit_negative_text = gr.Text(label='Negative prompt', info="Will be added to negative prompts")
with gr.Row() as row_random_prompt:
with gr.Column(scale=8):
random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
@@ -198,6 +200,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.taginfo,
self.edit_activation_text,
self.slider_preferred_weight,
+ self.edit_negative_text,
row_random_prompt,
random_prompt,
]
@@ -211,7 +214,9 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.select_sd_version,
self.edit_activation_text,
self.slider_preferred_weight,
+ self.edit_negative_text,
self.edit_notes,
]
+
self.setup_save_handler(self.button_save, self.save_lora_user_metadata, edited_components)
diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index db612fa2..66d15dd0 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -48,6 +48,11 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): if activation_text:
item["prompt"] += " + " + quote_js(" " + activation_text)
+ negative_prompt = item["user_metadata"].get("negative text")
+ item["negative_prompt"] = quote_js("")
+ if negative_prompt:
+ item["negative_prompt"] = quote_js('(' + negative_prompt + ':1)')
+
sd_version = item["user_metadata"].get("sd version")
if sd_version in network.SdVersion.__members__:
item["sd_version"] = sd_version
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 167d2f64..fe5e5a19 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -1,16 +1,9 @@ import sys import PIL.Image -import numpy as np -import torch -from tqdm import tqdm import modules.upscaler -from modules import devices, modelloader, script_callbacks, errors -from scunet_model_arch import SCUNet - -from modules.modelloader import load_file_from_url -from modules.shared import opts +from modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils class UpscalerScuNET(modules.upscaler.Upscaler): @@ -42,100 +35,37 @@ class UpscalerScuNET(modules.upscaler.Upscaler): scalers.append(scaler_data2) self.scalers = scalers - @staticmethod - @torch.no_grad() - def tiled_inference(img, model): - # test the image tile by tile - h, w = img.shape[2:] - tile = opts.SCUNET_tile - tile_overlap = opts.SCUNET_tile_overlap - if tile == 0: - return model(img) - - device = devices.get_device_for('scunet') - assert tile % 8 == 0, "tile size should be a multiple of window_size" - sf = 1 - - stride = tile - tile_overlap - h_idx_list = list(range(0, h - tile, stride)) + [h - tile] - w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device) - W = torch.zeros_like(E, dtype=devices.dtype, device=device) - - with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar: - for h_idx in h_idx_list: - - for w_idx in w_idx_list: - - in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] - - out_patch = model(in_patch) - out_patch_mask = torch.ones_like(out_patch) - - E[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch) - W[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch_mask) - pbar.update(1) - output = E.div_(W) - - return output - def do_upscale(self, img: PIL.Image.Image, selected_file): - devices.torch_gc() - try: model = self.load_model(selected_file) except Exception as e: print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr) return img - device = devices.get_device_for('scunet') - tile = opts.SCUNET_tile - h, w = img.height, img.width - np_img = np.array(img) - np_img = np_img[:, :, ::-1] # RGB to BGR - np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW - torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore - - if tile > h or tile > w: - _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device) - _img[:, :, :h, :w] = torch_img # pad image - torch_img = _img - - torch_output = self.tiled_inference(torch_img, model).squeeze(0) - torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any - np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() - del torch_img, torch_output + img = upscaler_utils.upscale_2( + img, + model, + tile_size=shared.opts.SCUNET_tile, + tile_overlap=shared.opts.SCUNET_tile_overlap, + scale=1, # ScuNET is a denoising model, not an upscaler + desc='ScuNET', + ) devices.torch_gc() - - output = np_output.transpose((1, 2, 0)) # CHW to HWC - output = output[:, :, ::-1] # BGR to RGB - return PIL.Image.fromarray((output * 255).astype(np.uint8)) + return img def load_model(self, path: str): device = devices.get_device_for('scunet') if path.startswith("http"): # TODO: this doesn't use `path` at all? - filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") + filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path - model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) - model.load_state_dict(torch.load(filename), strict=True) - model.eval() - for _, v in model.named_parameters(): - v.requires_grad = False - model = model.to(device) - - return model + return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet') def on_ui_settings(): import gradio as gr - from modules import shared shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling")) shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam")) diff --git a/extensions-builtin/ScuNET/scunet_model_arch.py b/extensions-builtin/ScuNET/scunet_model_arch.py deleted file mode 100644 index b51a8806..00000000 --- a/extensions-builtin/ScuNET/scunet_model_arch.py +++ /dev/null @@ -1,268 +0,0 @@ -# -*- coding: utf-8 -*- -import numpy as np -import torch -import torch.nn as nn -from einops import rearrange -from einops.layers.torch import Rearrange -from timm.models.layers import trunc_normal_, DropPath - - -class WMSA(nn.Module): - """ Self-attention module in Swin Transformer - """ - - def __init__(self, input_dim, output_dim, head_dim, window_size, type): - super(WMSA, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - self.head_dim = head_dim - self.scale = self.head_dim ** -0.5 - self.n_heads = input_dim // head_dim - self.window_size = window_size - self.type = type - self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True) - - self.relative_position_params = nn.Parameter( - torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)) - - self.linear = nn.Linear(self.input_dim, self.output_dim) - - trunc_normal_(self.relative_position_params, std=.02) - self.relative_position_params = torch.nn.Parameter( - self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1, - 2).transpose( - 0, 1)) - - def generate_mask(self, h, w, p, shift): - """ generating the mask of SW-MSA - Args: - shift: shift parameters in CyclicShift. - Returns: - attn_mask: should be (1 1 w p p), - """ - # supporting square. - attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device) - if self.type == 'W': - return attn_mask - - s = p - shift - attn_mask[-1, :, :s, :, s:, :] = True - attn_mask[-1, :, s:, :, :s, :] = True - attn_mask[:, -1, :, :s, :, s:] = True - attn_mask[:, -1, :, s:, :, :s] = True - attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)') - return attn_mask - - def forward(self, x): - """ Forward pass of Window Multi-head Self-attention module. - Args: - x: input tensor with shape of [b h w c]; - attn_mask: attention mask, fill -inf where the value is True; - Returns: - output: tensor shape [b h w c] - """ - if self.type != 'W': - x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2)) - - x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) - h_windows = x.size(1) - w_windows = x.size(2) - # square validation - # assert h_windows == w_windows - - x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size) - qkv = self.embedding_layer(x) - q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0) - sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale - # Adding learnable relative embedding - sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q') - # Using Attn Mask to distinguish different subwindows. - if self.type != 'W': - attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2) - sim = sim.masked_fill_(attn_mask, float("-inf")) - - probs = nn.functional.softmax(sim, dim=-1) - output = torch.einsum('hbwij,hbwjc->hbwic', probs, v) - output = rearrange(output, 'h b w p c -> b w p (h c)') - output = self.linear(output) - output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size) - - if self.type != 'W': - output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2), dims=(1, 2)) - - return output - - def relative_embedding(self): - cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)])) - relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1 - # negative is allowed - return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()] - - -class Block(nn.Module): - def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): - """ SwinTransformer Block - """ - super(Block, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - assert type in ['W', 'SW'] - self.type = type - if input_resolution <= window_size: - self.type = 'W' - - self.ln1 = nn.LayerNorm(input_dim) - self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type) - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.ln2 = nn.LayerNorm(input_dim) - self.mlp = nn.Sequential( - nn.Linear(input_dim, 4 * input_dim), - nn.GELU(), - nn.Linear(4 * input_dim, output_dim), - ) - - def forward(self, x): - x = x + self.drop_path(self.msa(self.ln1(x))) - x = x + self.drop_path(self.mlp(self.ln2(x))) - return x - - -class ConvTransBlock(nn.Module): - def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None): - """ SwinTransformer and Conv Block - """ - super(ConvTransBlock, self).__init__() - self.conv_dim = conv_dim - self.trans_dim = trans_dim - self.head_dim = head_dim - self.window_size = window_size - self.drop_path = drop_path - self.type = type - self.input_resolution = input_resolution - - assert self.type in ['W', 'SW'] - if self.input_resolution <= self.window_size: - self.type = 'W' - - self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path, - self.type, self.input_resolution) - self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) - self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True) - - self.conv_block = nn.Sequential( - nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False), - nn.ReLU(True), - nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False) - ) - - def forward(self, x): - conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1) - conv_x = self.conv_block(conv_x) + conv_x - trans_x = Rearrange('b c h w -> b h w c')(trans_x) - trans_x = self.trans_block(trans_x) - trans_x = Rearrange('b h w c -> b c h w')(trans_x) - res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1)) - x = x + res - - return x - - -class SCUNet(nn.Module): - # def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256): - def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256): - super(SCUNet, self).__init__() - if config is None: - config = [2, 2, 2, 2, 2, 2, 2] - self.config = config - self.dim = dim - self.head_dim = 32 - self.window_size = 8 - - # drop path rate for each layer - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))] - - self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)] - - begin = 0 - self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution) - for i in range(config[0])] + \ - [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)] - - begin += config[0] - self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 2) - for i in range(config[1])] + \ - [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)] - - begin += config[1] - self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 4) - for i in range(config[2])] + \ - [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)] - - begin += config[2] - self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 8) - for i in range(config[3])] - - begin += config[3] - self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 4) - for i in range(config[4])] - - begin += config[4] - self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution // 2) - for i in range(config[5])] - - begin += config[5] - self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \ - [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin], - 'W' if not i % 2 else 'SW', input_resolution) - for i in range(config[6])] - - self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)] - - self.m_head = nn.Sequential(*self.m_head) - self.m_down1 = nn.Sequential(*self.m_down1) - self.m_down2 = nn.Sequential(*self.m_down2) - self.m_down3 = nn.Sequential(*self.m_down3) - self.m_body = nn.Sequential(*self.m_body) - self.m_up3 = nn.Sequential(*self.m_up3) - self.m_up2 = nn.Sequential(*self.m_up2) - self.m_up1 = nn.Sequential(*self.m_up1) - self.m_tail = nn.Sequential(*self.m_tail) - # self.apply(self._init_weights) - - def forward(self, x0): - - h, w = x0.size()[-2:] - paddingBottom = int(np.ceil(h / 64) * 64 - h) - paddingRight = int(np.ceil(w / 64) * 64 - w) - x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0) - - x1 = self.m_head(x0) - x2 = self.m_down1(x1) - x3 = self.m_down2(x2) - x4 = self.m_down3(x3) - x = self.m_body(x4) - x = self.m_up3(x + x4) - x = self.m_up2(x + x3) - x = self.m_up1(x + x2) - x = self.m_tail(x + x1) - - x = x[..., :h, :w] - - return x - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index ae0d0e6a..16bf9b79 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,20 +1,15 @@ +import logging import sys -import platform -import numpy as np import torch from PIL import Image -from tqdm import tqdm -from modules import modelloader, devices, script_callbacks, shared -from modules.shared import opts, state -from swinir_model_arch import SwinIR -from swinir_model_arch_v2 import Swin2SR +from modules import devices, modelloader, script_callbacks, shared, upscaler_utils from modules.upscaler import Upscaler, UpscalerData SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" -device_swinir = devices.get_device_for('swinir') +logger = logging.getLogger(__name__) class UpscalerSwinIR(Upscaler): @@ -37,26 +32,28 @@ class UpscalerSwinIR(Upscaler): scalers.append(model_data) self.scalers = scalers - def do_upscale(self, img, model_file): - use_compile = hasattr(opts, 'SWIN_torch_compile') and opts.SWIN_torch_compile \ - and int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows" - current_config = (model_file, opts.SWIN_tile) + def do_upscale(self, img: Image.Image, model_file: str) -> Image.Image: + current_config = (model_file, shared.opts.SWIN_tile) - if use_compile and self._cached_model_config == current_config: + if self._cached_model_config == current_config: model = self._cached_model else: - self._cached_model = None try: model = self.load_model(model_file) except Exception as e: print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr) return img - model = model.to(device_swinir, dtype=devices.dtype) - if use_compile: - model = torch.compile(model) - self._cached_model = model - self._cached_model_config = current_config - img = upscale(img, model) + self._cached_model = model + self._cached_model_config = current_config + + img = upscaler_utils.upscale_2( + img, + model, + tile_size=shared.opts.SWIN_tile, + tile_overlap=shared.opts.SWIN_tile_overlap, + scale=model.scale, + desc="SwinIR", + ) devices.torch_gc() return img @@ -69,115 +66,22 @@ class UpscalerSwinIR(Upscaler): ) else: filename = path - if filename.endswith(".v2.pth"): - model = Swin2SR( - upscale=scale, - in_chans=3, - img_size=64, - window_size=8, - img_range=1.0, - depths=[6, 6, 6, 6, 6, 6], - embed_dim=180, - num_heads=[6, 6, 6, 6, 6, 6], - mlp_ratio=2, - upsampler="nearest+conv", - resi_connection="1conv", - ) - params = None - else: - model = SwinIR( - upscale=scale, - in_chans=3, - img_size=64, - window_size=8, - img_range=1.0, - depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], - embed_dim=240, - num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], - mlp_ratio=2, - upsampler="nearest+conv", - resi_connection="3conv", - ) - params = "params_ema" - pretrained_model = torch.load(filename) - if params is not None: - model.load_state_dict(pretrained_model[params], strict=True) - else: - model.load_state_dict(pretrained_model, strict=True) - return model - - -def upscale( - img, - model, - tile=None, - tile_overlap=None, - window_size=8, - scale=4, -): - tile = tile or opts.SWIN_tile - tile_overlap = tile_overlap or opts.SWIN_tile_overlap - - - img = np.array(img) - img = img[:, :, ::-1] - img = np.moveaxis(img, 2, 0) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype) - with torch.no_grad(), devices.autocast(): - _, _, h_old, w_old = img.size() - h_pad = (h_old // window_size + 1) * window_size - h_old - w_pad = (w_old // window_size + 1) * window_size - w_old - img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] - img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] - output = inference(img, model, tile, tile_overlap, window_size, scale) - output = output[..., : h_old * scale, : w_old * scale] - output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() - if output.ndim == 3: - output = np.transpose( - output[[2, 1, 0], :, :], (1, 2, 0) - ) # CHW-RGB to HCW-BGR - output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 - return Image.fromarray(output, "RGB") - - -def inference(img, model, tile, tile_overlap, window_size, scale): - # test the image tile by tile - b, c, h, w = img.size() - tile = min(tile, h, w) - assert tile % window_size == 0, "tile size should be a multiple of window_size" - sf = scale - - stride = tile - tile_overlap - h_idx_list = list(range(0, h - tile, stride)) + [h - tile] - w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img) - W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir) - - with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: - for h_idx in h_idx_list: - if state.interrupted or state.skipped: - break - - for w_idx in w_idx_list: - if state.interrupted or state.skipped: - break - - in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] - out_patch = model(in_patch) - out_patch_mask = torch.ones_like(out_patch) - - E[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch) - W[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch_mask) - pbar.update(1) - output = E.div_(W) - - return output + model_descriptor = modelloader.load_spandrel_model( + filename, + device=self._get_device(), + prefer_half=(devices.dtype == torch.float16), + expected_architecture="SwinIR", + ) + if getattr(shared.opts, 'SWIN_torch_compile', False): + try: + model_descriptor.model.compile() + except Exception: + logger.warning("Failed to compile SwinIR model, fallback to JIT", exc_info=True) + return model_descriptor + + def _get_device(self): + return devices.get_device_for('swinir') def on_ui_settings(): @@ -185,8 +89,7 @@ def on_ui_settings(): shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling"))) shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling"))) - if int(torch.__version__.split('.')[0]) >= 2 and platform.system() != "Windows": # torch.compile() require pytorch 2.0 or above, and not on Windows - shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run")) + shared.opts.add_option("SWIN_torch_compile", shared.OptionInfo(False, "Use torch.compile to accelerate SwinIR.", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")).info("Takes longer on first run")) script_callbacks.on_ui_settings(on_ui_settings) diff --git a/extensions-builtin/SwinIR/swinir_model_arch.py b/extensions-builtin/SwinIR/swinir_model_arch.py deleted file mode 100644 index 93b93274..00000000 --- a/extensions-builtin/SwinIR/swinir_model_arch.py +++ /dev/null @@ -1,867 +0,0 @@ -# ----------------------------------------------------------------------------------- -# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 -# Originally Written by Ze Liu, Modified by Jingyun Liang. -# ----------------------------------------------------------------------------------- - -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -def window_partition(x, window_size): - """ - Args: - x: (B, H, W, C) - window_size (int): window size - - Returns: - windows: (num_windows*B, window_size, window_size, C) - """ - B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) - return windows - - -def window_reverse(windows, window_size, H, W): - """ - Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image - - Returns: - x: (B, H, W, C) - """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) - x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) - return x - - -class WindowAttention(nn.Module): - r""" Window based multi-head self attention (W-MSA) module with relative position bias. - It supports both of shifted and non-shifted window. - - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set - attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - """ - - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): - - super().__init__() - self.dim = dim - self.window_size = window_size # Wh, Ww - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 - - # define a parameter table of relative position bias - self.relative_position_bias_table = nn.Parameter( - torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - - self.proj_drop = nn.Dropout(proj_drop) - - trunc_normal_(self.relative_position_bias_table, std=.02) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None - """ - B_, N, C = x.shape - qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = (q @ k.transpose(-2, -1)) - - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops - - -class SwinTransformerBlock(nn.Module): - r""" Swin Transformer Block. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - num_heads (int): Number of attention heads. - window_size (int): Window size. - shift_size (int): Shift size for SW-MSA. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float, optional): Stochastic depth rate. Default: 0.0 - act_layer (nn.Module, optional): Activation layer. Default: nn.GELU - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., - act_layer=nn.GELU, norm_layer=nn.LayerNorm): - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size - self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" - - self.norm1 = norm_layer(dim) - self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - - if self.shift_size > 0: - attn_mask = self.calculate_mask(self.input_resolution) - else: - attn_mask = None - - self.register_buffer("attn_mask", attn_mask) - - def calculate_mask(self, x_size): - # calculate attention mask for SW-MSA - H, W = x_size - img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 - h_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - w_slices = (slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)) - cnt = 0 - for h in h_slices: - for w in w_slices: - img_mask[:, h, w, :] = cnt - cnt += 1 - - mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) - attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) - - return attn_mask - - def forward(self, x, x_size): - H, W = x_size - B, L, C = x.shape - # assert L == H * W, "input feature has wrong size" - - shortcut = x - x = self.norm1(x) - x = x.view(B, H, W, C) - - # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) - else: - shifted_x = x - - # partition windows - x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C - - # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size - if self.input_resolution == x_size: - attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C - else: - attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) - - # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C - - # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) - else: - x = shifted_x - x = x.view(B, H * W, C) - - # FFN - x = shortcut + self.drop_path(x) - x = x + self.drop_path(self.mlp(self.norm2(x))) - - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ - f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" - - def flops(self): - flops = 0 - H, W = self.input_resolution - # norm1 - flops += self.dim * H * W - # W-MSA/SW-MSA - nW = H * W / self.window_size / self.window_size - flops += nW * self.attn.flops(self.window_size * self.window_size) - # mlp - flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio - # norm2 - flops += self.dim * H * W - return flops - - -class PatchMerging(nn.Module): - r""" Patch Merging Layer. - - Args: - input_resolution (tuple[int]): Resolution of input feature. - dim (int): Number of input channels. - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - """ - - def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): - super().__init__() - self.input_resolution = input_resolution - self.dim = dim - self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) - self.norm = norm_layer(4 * dim) - - def forward(self, x): - """ - x: B, H*W, C - """ - H, W = self.input_resolution - B, L, C = x.shape - assert L == H * W, "input feature has wrong size" - assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." - - x = x.view(B, H, W, C) - - x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C - x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C - x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C - x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C - x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C - x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C - - x = self.norm(x) - x = self.reduction(x) - - return x - - def extra_repr(self) -> str: - return f"input_resolution={self.input_resolution}, dim={self.dim}" - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.dim - flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim - return flops - - -class BasicLayer(nn.Module): - """ A basic Swin Transformer layer for one stage. - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): - - super().__init__() - self.dim = dim - self.input_resolution = input_resolution - self.depth = depth - self.use_checkpoint = use_checkpoint - - # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) - for i in range(depth)]) - - # patch merging layer - if downsample is not None: - self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) - else: - self.downsample = None - - def forward(self, x, x_size): - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x, x_size) - else: - x = blk(x, x_size) - if self.downsample is not None: - x = self.downsample(x) - return x - - def extra_repr(self) -> str: - return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" - - def flops(self): - flops = 0 - for blk in self.blocks: - flops += blk.flops() - if self.downsample is not None: - flops += self.downsample.flops() - return flops - - -class RSTB(nn.Module): - """Residual Swin Transformer Block (RSTB). - - Args: - dim (int): Number of input channels. - input_resolution (tuple[int]): Input resolution. - depth (int): Number of blocks. - num_heads (int): Number of attention heads. - window_size (int): Local window size. - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. - qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. - drop (float, optional): Dropout rate. Default: 0.0 - attn_drop (float, optional): Attention dropout rate. Default: 0.0 - drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 - norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm - downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. - img_size: Input image size. - patch_size: Patch size. - resi_connection: The convolutional block before residual connection. - """ - - def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, - img_size=224, patch_size=4, resi_connection='1conv'): - super(RSTB, self).__init__() - - self.dim = dim - self.input_resolution = input_resolution - - self.residual_group = BasicLayer(dim=dim, - input_resolution=input_resolution, - depth=depth, - num_heads=num_heads, - window_size=window_size, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path, - norm_layer=norm_layer, - downsample=downsample, - use_checkpoint=use_checkpoint) - - if resi_connection == '1conv': - self.conv = nn.Conv2d(dim, dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(dim // 4, dim, 3, 1, 1)) - - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, - norm_layer=None) - - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, - norm_layer=None) - - def forward(self, x, x_size): - return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x - - def flops(self): - flops = 0 - flops += self.residual_group.flops() - H, W = self.input_resolution - flops += H * W * self.dim * self.dim * 9 - flops += self.patch_embed.flops() - flops += self.patch_unembed.flops() - - return flops - - -class PatchEmbed(nn.Module): - r""" Image to Patch Embedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - if norm_layer is not None: - self.norm = norm_layer(embed_dim) - else: - self.norm = None - - def forward(self, x): - x = x.flatten(2).transpose(1, 2) # B Ph*Pw C - if self.norm is not None: - x = self.norm(x) - return x - - def flops(self): - flops = 0 - H, W = self.img_size - if self.norm is not None: - flops += H * W * self.embed_dim - return flops - - -class PatchUnEmbed(nn.Module): - r""" Image to Patch Unembedding - - Args: - img_size (int): Image size. Default: 224. - patch_size (int): Patch token size. Default: 4. - in_chans (int): Number of input image channels. Default: 3. - embed_dim (int): Number of linear projection output channels. Default: 96. - norm_layer (nn.Module, optional): Normalization layer. Default: None - """ - - def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] - self.img_size = img_size - self.patch_size = patch_size - self.patches_resolution = patches_resolution - self.num_patches = patches_resolution[0] * patches_resolution[1] - - self.in_chans = in_chans - self.embed_dim = embed_dim - - def forward(self, x, x_size): - B, HW, C = x.shape - x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C - return x - - def flops(self): - flops = 0 - return flops - - -class Upsample(nn.Sequential): - """Upsample module. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - """ - - def __init__(self, scale, num_feat): - m = [] - if (scale & (scale - 1)) == 0: # scale = 2^n - for _ in range(int(math.log(scale, 2))): - m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(2)) - elif scale == 3: - m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) - m.append(nn.PixelShuffle(3)) - else: - raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') - super(Upsample, self).__init__(*m) - - -class UpsampleOneStep(nn.Sequential): - """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) - Used in lightweight SR to save parameters. - - Args: - scale (int): Scale factor. Supported scales: 2^n and 3. - num_feat (int): Channel number of intermediate features. - - """ - - def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): - self.num_feat = num_feat - self.input_resolution = input_resolution - m = [] - m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) - m.append(nn.PixelShuffle(scale)) - super(UpsampleOneStep, self).__init__(*m) - - def flops(self): - H, W = self.input_resolution - flops = H * W * self.num_feat * 3 * 9 - return flops - - -class SwinIR(nn.Module): - r""" SwinIR - A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. - - Args: - img_size (int | tuple(int)): Input image size. Default 64 - patch_size (int | tuple(int)): Patch size. Default: 1 - in_chans (int): Number of input image channels. Default: 3 - embed_dim (int): Patch embedding dimension. Default: 96 - depths (tuple(int)): Depth of each Swin Transformer layer. - num_heads (tuple(int)): Number of attention heads in different layers. - window_size (int): Window size. Default: 7 - mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 - qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None - drop_rate (float): Dropout rate. Default: 0 - attn_drop_rate (float): Attention dropout rate. Default: 0 - drop_path_rate (float): Stochastic depth rate. Default: 0.1 - norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. - ape (bool): If True, add absolute position embedding to the patch embedding. Default: False - patch_norm (bool): If True, add normalization after patch embedding. Default: True - use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False - upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction - img_range: Image range. 1. or 255. - upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None - resi_connection: The convolutional block before residual connection. '1conv'/'3conv' - """ - - def __init__(self, img_size=64, patch_size=1, in_chans=3, - embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6), - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', - **kwargs): - super(SwinIR, self).__init__() - num_in_ch = in_chans - num_out_ch = in_chans - num_feat = 64 - self.img_range = img_range - if in_chans == 3: - rgb_mean = (0.4488, 0.4371, 0.4040) - self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) - else: - self.mean = torch.zeros(1, 1, 1, 1) - self.upscale = upscale - self.upsampler = upsampler - self.window_size = window_size - - ##################################################################################################### - ################################### 1, shallow feature extraction ################################### - self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) - - ##################################################################################################### - ################################### 2, deep feature extraction ###################################### - self.num_layers = len(depths) - self.embed_dim = embed_dim - self.ape = ape - self.patch_norm = patch_norm - self.num_features = embed_dim - self.mlp_ratio = mlp_ratio - - # split image into non-overlapping patches - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - num_patches = self.patch_embed.num_patches - patches_resolution = self.patch_embed.patches_resolution - self.patches_resolution = patches_resolution - - # merge non-overlapping patches into image - self.patch_unembed = PatchUnEmbed( - img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, - norm_layer=norm_layer if self.patch_norm else None) - - # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - - self.pos_drop = nn.Dropout(p=drop_rate) - - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - - # build Residual Swin Transformer blocks (RSTB) - self.layers = nn.ModuleList() - for i_layer in range(self.num_layers): - layer = RSTB(dim=embed_dim, - input_resolution=(patches_resolution[0], - patches_resolution[1]), - depth=depths[i_layer], - num_heads=num_heads[i_layer], - window_size=window_size, - mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, - drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results - norm_layer=norm_layer, - downsample=None, - use_checkpoint=use_checkpoint, - img_size=img_size, - patch_size=patch_size, - resi_connection=resi_connection - - ) - self.layers.append(layer) - self.norm = norm_layer(self.num_features) - - # build the last conv layer in deep feature extraction - if resi_connection == '1conv': - self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) - elif resi_connection == '3conv': - # to save parameters and memory - self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) - - ##################################################################################################### - ################################ 3, high quality image reconstruction ################################ - if self.upsampler == 'pixelshuffle': - # for classical SR - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.upsample = Upsample(upscale, num_feat) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR (to save parameters) - self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, - (patches_resolution[0], patches_resolution[1])) - elif self.upsampler == 'nearest+conv': - # for real-world SR (less artifacts) - self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), - nn.LeakyReLU(inplace=True)) - self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - if self.upscale == 4: - self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) - self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) - self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - else: - # for image denoising and JPEG compression artifact reduction - self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) - - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - @torch.jit.ignore - def no_weight_decay(self): - return {'absolute_pos_embed'} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - return {'relative_position_bias_table'} - - def check_image_size(self, x): - _, _, h, w = x.size() - mod_pad_h = (self.window_size - h % self.window_size) % self.window_size - mod_pad_w = (self.window_size - w % self.window_size) % self.window_size - x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') - return x - - def forward_features(self, x): - x_size = (x.shape[2], x.shape[3]) - x = self.patch_embed(x) - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x, x_size) - - x = self.norm(x) # B L C - x = self.patch_unembed(x, x_size) - - return x - - def forward(self, x): - H, W = x.shape[2:] - x = self.check_image_size(x) - - self.mean = self.mean.type_as(x) - x = (x - self.mean) * self.img_range - - if self.upsampler == 'pixelshuffle': - # for classical SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.conv_last(self.upsample(x)) - elif self.upsampler == 'pixelshuffledirect': - # for lightweight SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.upsample(x) - elif self.upsampler == 'nearest+conv': - # for real-world SR - x = self.conv_first(x) - x = self.conv_after_body(self.forward_features(x)) + x - x = self.conv_before_upsample(x) - x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - if self.upscale == 4: - x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) - x = self.conv_last(self.lrelu(self.conv_hr(x))) - else: - # for image denoising and JPEG compression artifact reduction - x_first = self.conv_first(x) - res = self.conv_after_body(self.forward_features(x_first)) + x_first - x = x + self.conv_last(res) - - x = x / self.img_range + self.mean - - return x[:, :, :H*self.upscale, :W*self.upscale] - - def flops(self): - flops = 0 - H, W = self.patches_resolution - flops += H * W * 3 * self.embed_dim * 9 - flops += self.patch_embed.flops() - for layer in self.layers: - flops += layer.flops() - flops += H * W * 3 * self.embed_dim * self.embed_dim - flops += self.upsample.flops() - return flops - - -if __name__ == '__main__': - upscale = 4 - window_size = 8 - height = (1024 // upscale // window_size + 1) * window_size - width = (720 // upscale // window_size + 1) * window_size - model = SwinIR(upscale=2, img_size=(height, width), - window_size=window_size, img_range=1., depths=[6, 6, 6, 6], - embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') - print(model) - print(height, width, model.flops() / 1e9) - - x = torch.randn((1, 3, height, width)) - x = model(x) - print(x.shape) diff --git a/extensions-builtin/SwinIR/swinir_model_arch_v2.py b/extensions-builtin/SwinIR/swinir_model_arch_v2.py deleted file mode 100644 index dad22cca..00000000 --- a/extensions-builtin/SwinIR/swinir_model_arch_v2.py +++ /dev/null @@ -1,1017 +0,0 @@ -# -----------------------------------------------------------------------------------
-# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/
-# Written by Conde and Choi et al.
-# -----------------------------------------------------------------------------------
-
-import math
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.utils.checkpoint as checkpoint
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-
-
-class Mlp(nn.Module):
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- self.fc1 = nn.Linear(in_features, hidden_features)
- self.act = act_layer()
- self.fc2 = nn.Linear(hidden_features, out_features)
- self.drop = nn.Dropout(drop)
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop(x)
- x = self.fc2(x)
- x = self.drop(x)
- return x
-
-
-def window_partition(x, window_size):
- """
- Args:
- x: (B, H, W, C)
- window_size (int): window size
- Returns:
- windows: (num_windows*B, window_size, window_size, C)
- """
- B, H, W, C = x.shape
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
- return windows
-
-
-def window_reverse(windows, window_size, H, W):
- """
- Args:
- windows: (num_windows*B, window_size, window_size, C)
- window_size (int): Window size
- H (int): Height of image
- W (int): Width of image
- Returns:
- x: (B, H, W, C)
- """
- B = int(windows.shape[0] / (H * W / window_size / window_size))
- x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
- return x
-
-class WindowAttention(nn.Module):
- r""" Window based multi-head self attention (W-MSA) module with relative position bias.
- It supports both of shifted and non-shifted window.
- Args:
- dim (int): Number of input channels.
- window_size (tuple[int]): The height and width of the window.
- num_heads (int): Number of attention heads.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
- pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
- """
-
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
- pretrained_window_size=(0, 0)):
-
- super().__init__()
- self.dim = dim
- self.window_size = window_size # Wh, Ww
- self.pretrained_window_size = pretrained_window_size
- self.num_heads = num_heads
-
- self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
-
- # mlp to generate continuous relative position bias
- self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
- nn.ReLU(inplace=True),
- nn.Linear(512, num_heads, bias=False))
-
- # get relative_coords_table
- relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
- relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
- relative_coords_table = torch.stack(
- torch.meshgrid([relative_coords_h,
- relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
- if pretrained_window_size[0] > 0:
- relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
- relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
- else:
- relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
- relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
- relative_coords_table *= 8 # normalize to -8, 8
- relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
- torch.abs(relative_coords_table) + 1.0) / np.log2(8)
-
- self.register_buffer("relative_coords_table", relative_coords_table)
-
- # get pair-wise relative position index for each token inside the window
- coords_h = torch.arange(self.window_size[0])
- coords_w = torch.arange(self.window_size[1])
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
- relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
- relative_coords[:, :, 1] += self.window_size[1] - 1
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
- self.register_buffer("relative_position_index", relative_position_index)
-
- self.qkv = nn.Linear(dim, dim * 3, bias=False)
- if qkv_bias:
- self.q_bias = nn.Parameter(torch.zeros(dim))
- self.v_bias = nn.Parameter(torch.zeros(dim))
- else:
- self.q_bias = None
- self.v_bias = None
- self.attn_drop = nn.Dropout(attn_drop)
- self.proj = nn.Linear(dim, dim)
- self.proj_drop = nn.Dropout(proj_drop)
- self.softmax = nn.Softmax(dim=-1)
-
- def forward(self, x, mask=None):
- """
- Args:
- x: input features with shape of (num_windows*B, N, C)
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
- """
- B_, N, C = x.shape
- qkv_bias = None
- if self.q_bias is not None:
- qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
- qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
- qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
-
- # cosine attention
- attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
- logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp()
- attn = attn * logit_scale
-
- relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
- relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
- relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
- attn = attn + relative_position_bias.unsqueeze(0)
-
- if mask is not None:
- nW = mask.shape[0]
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
- attn = attn.view(-1, self.num_heads, N, N)
- attn = self.softmax(attn)
- else:
- attn = self.softmax(attn)
-
- attn = self.attn_drop(attn)
-
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
- x = self.proj(x)
- x = self.proj_drop(x)
- return x
-
- def extra_repr(self) -> str:
- return f'dim={self.dim}, window_size={self.window_size}, ' \
- f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}'
-
- def flops(self, N):
- # calculate flops for 1 window with token length of N
- flops = 0
- # qkv = self.qkv(x)
- flops += N * self.dim * 3 * self.dim
- # attn = (q @ k.transpose(-2, -1))
- flops += self.num_heads * N * (self.dim // self.num_heads) * N
- # x = (attn @ v)
- flops += self.num_heads * N * N * (self.dim // self.num_heads)
- # x = self.proj(x)
- flops += N * self.dim * self.dim
- return flops
-
-class SwinTransformerBlock(nn.Module):
- r""" Swin Transformer Block.
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resulotion.
- num_heads (int): Number of attention heads.
- window_size (int): Window size.
- shift_size (int): Shift size for SW-MSA.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- pretrained_window_size (int): Window size in pre-training.
- """
-
- def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
- act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.num_heads = num_heads
- self.window_size = window_size
- self.shift_size = shift_size
- self.mlp_ratio = mlp_ratio
- if min(self.input_resolution) <= self.window_size:
- # if window size is larger than input resolution, we don't partition windows
- self.shift_size = 0
- self.window_size = min(self.input_resolution)
- assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
-
- self.norm1 = norm_layer(dim)
- self.attn = WindowAttention(
- dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
- qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
- pretrained_window_size=to_2tuple(pretrained_window_size))
-
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
- self.norm2 = norm_layer(dim)
- mlp_hidden_dim = int(dim * mlp_ratio)
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
-
- if self.shift_size > 0:
- attn_mask = self.calculate_mask(self.input_resolution)
- else:
- attn_mask = None
-
- self.register_buffer("attn_mask", attn_mask)
-
- def calculate_mask(self, x_size):
- # calculate attention mask for SW-MSA
- H, W = x_size
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
- h_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- w_slices = (slice(0, -self.window_size),
- slice(-self.window_size, -self.shift_size),
- slice(-self.shift_size, None))
- cnt = 0
- for h in h_slices:
- for w in w_slices:
- img_mask[:, h, w, :] = cnt
- cnt += 1
-
- mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
-
- return attn_mask
-
- def forward(self, x, x_size):
- H, W = x_size
- B, L, C = x.shape
- #assert L == H * W, "input feature has wrong size"
-
- shortcut = x
- x = x.view(B, H, W, C)
-
- # cyclic shift
- if self.shift_size > 0:
- shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
- else:
- shifted_x = x
-
- # partition windows
- x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
- x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
-
- # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
- if self.input_resolution == x_size:
- attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
- else:
- attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
-
- # merge windows
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
-
- # reverse cyclic shift
- if self.shift_size > 0:
- x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
- else:
- x = shifted_x
- x = x.view(B, H * W, C)
- x = shortcut + self.drop_path(self.norm1(x))
-
- # FFN
- x = x + self.drop_path(self.norm2(self.mlp(x)))
-
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
- f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
-
- def flops(self):
- flops = 0
- H, W = self.input_resolution
- # norm1
- flops += self.dim * H * W
- # W-MSA/SW-MSA
- nW = H * W / self.window_size / self.window_size
- flops += nW * self.attn.flops(self.window_size * self.window_size)
- # mlp
- flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
- # norm2
- flops += self.dim * H * W
- return flops
-
-class PatchMerging(nn.Module):
- r""" Patch Merging Layer.
- Args:
- input_resolution (tuple[int]): Resolution of input feature.
- dim (int): Number of input channels.
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- """
-
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
- super().__init__()
- self.input_resolution = input_resolution
- self.dim = dim
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
- self.norm = norm_layer(2 * dim)
-
- def forward(self, x):
- """
- x: B, H*W, C
- """
- H, W = self.input_resolution
- B, L, C = x.shape
- assert L == H * W, "input feature has wrong size"
- assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
-
- x = x.view(B, H, W, C)
-
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
-
- x = self.reduction(x)
- x = self.norm(x)
-
- return x
-
- def extra_repr(self) -> str:
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
-
- def flops(self):
- H, W = self.input_resolution
- flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
- flops += H * W * self.dim // 2
- return flops
-
-class BasicLayer(nn.Module):
- """ A basic Swin Transformer layer for one stage.
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- pretrained_window_size (int): Local window size in pre-training.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
- pretrained_window_size=0):
-
- super().__init__()
- self.dim = dim
- self.input_resolution = input_resolution
- self.depth = depth
- self.use_checkpoint = use_checkpoint
-
- # build blocks
- self.blocks = nn.ModuleList([
- SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
- num_heads=num_heads, window_size=window_size,
- shift_size=0 if (i % 2 == 0) else window_size // 2,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
- norm_layer=norm_layer,
- pretrained_window_size=pretrained_window_size)
- for i in range(depth)])
-
- # patch merging layer
- if downsample is not None:
- self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
- else:
- self.downsample = None
-
- def forward(self, x, x_size):
- for blk in self.blocks:
- if self.use_checkpoint:
- x = checkpoint.checkpoint(blk, x, x_size)
- else:
- x = blk(x, x_size)
- if self.downsample is not None:
- x = self.downsample(x)
- return x
-
- def extra_repr(self) -> str:
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
-
- def flops(self):
- flops = 0
- for blk in self.blocks:
- flops += blk.flops()
- if self.downsample is not None:
- flops += self.downsample.flops()
- return flops
-
- def _init_respostnorm(self):
- for blk in self.blocks:
- nn.init.constant_(blk.norm1.bias, 0)
- nn.init.constant_(blk.norm1.weight, 0)
- nn.init.constant_(blk.norm2.bias, 0)
- nn.init.constant_(blk.norm2.weight, 0)
-
-class PatchEmbed(nn.Module):
- r""" Image to Patch Embedding
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
- if norm_layer is not None:
- self.norm = norm_layer(embed_dim)
- else:
- self.norm = None
-
- def forward(self, x):
- B, C, H, W = x.shape
- # FIXME look at relaxing size constraints
- # assert H == self.img_size[0] and W == self.img_size[1],
- # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
- x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
- if self.norm is not None:
- x = self.norm(x)
- return x
-
- def flops(self):
- Ho, Wo = self.patches_resolution
- flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
- if self.norm is not None:
- flops += Ho * Wo * self.embed_dim
- return flops
-
-class RSTB(nn.Module):
- """Residual Swin Transformer Block (RSTB).
-
- Args:
- dim (int): Number of input channels.
- input_resolution (tuple[int]): Input resolution.
- depth (int): Number of blocks.
- num_heads (int): Number of attention heads.
- window_size (int): Local window size.
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
- drop (float, optional): Dropout rate. Default: 0.0
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
- img_size: Input image size.
- patch_size: Patch size.
- resi_connection: The convolutional block before residual connection.
- """
-
- def __init__(self, dim, input_resolution, depth, num_heads, window_size,
- mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
- drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
- img_size=224, patch_size=4, resi_connection='1conv'):
- super(RSTB, self).__init__()
-
- self.dim = dim
- self.input_resolution = input_resolution
-
- self.residual_group = BasicLayer(dim=dim,
- input_resolution=input_resolution,
- depth=depth,
- num_heads=num_heads,
- window_size=window_size,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop, attn_drop=attn_drop,
- drop_path=drop_path,
- norm_layer=norm_layer,
- downsample=downsample,
- use_checkpoint=use_checkpoint)
-
- if resi_connection == '1conv':
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim, 3, 1, 1))
-
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
- norm_layer=None)
-
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim,
- norm_layer=None)
-
- def forward(self, x, x_size):
- return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
-
- def flops(self):
- flops = 0
- flops += self.residual_group.flops()
- H, W = self.input_resolution
- flops += H * W * self.dim * self.dim * 9
- flops += self.patch_embed.flops()
- flops += self.patch_unembed.flops()
-
- return flops
-
-class PatchUnEmbed(nn.Module):
- r""" Image to Patch Unembedding
-
- Args:
- img_size (int): Image size. Default: 224.
- patch_size (int): Patch token size. Default: 4.
- in_chans (int): Number of input image channels. Default: 3.
- embed_dim (int): Number of linear projection output channels. Default: 96.
- norm_layer (nn.Module, optional): Normalization layer. Default: None
- """
-
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
- super().__init__()
- img_size = to_2tuple(img_size)
- patch_size = to_2tuple(patch_size)
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
- self.img_size = img_size
- self.patch_size = patch_size
- self.patches_resolution = patches_resolution
- self.num_patches = patches_resolution[0] * patches_resolution[1]
-
- self.in_chans = in_chans
- self.embed_dim = embed_dim
-
- def forward(self, x, x_size):
- B, HW, C = x.shape
- x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
- return x
-
- def flops(self):
- flops = 0
- return flops
-
-
-class Upsample(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample, self).__init__(*m)
-
-class Upsample_hf(nn.Sequential):
- """Upsample module.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
- """
-
- def __init__(self, scale, num_feat):
- m = []
- if (scale & (scale - 1)) == 0: # scale = 2^n
- for _ in range(int(math.log(scale, 2))):
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(2))
- elif scale == 3:
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
- m.append(nn.PixelShuffle(3))
- else:
- raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
- super(Upsample_hf, self).__init__(*m)
-
-
-class UpsampleOneStep(nn.Sequential):
- """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
- Used in lightweight SR to save parameters.
-
- Args:
- scale (int): Scale factor. Supported scales: 2^n and 3.
- num_feat (int): Channel number of intermediate features.
-
- """
-
- def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
- self.num_feat = num_feat
- self.input_resolution = input_resolution
- m = []
- m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
- m.append(nn.PixelShuffle(scale))
- super(UpsampleOneStep, self).__init__(*m)
-
- def flops(self):
- H, W = self.input_resolution
- flops = H * W * self.num_feat * 3 * 9
- return flops
-
-
-
-class Swin2SR(nn.Module):
- r""" Swin2SR
- A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
-
- Args:
- img_size (int | tuple(int)): Input image size. Default 64
- patch_size (int | tuple(int)): Patch size. Default: 1
- in_chans (int): Number of input image channels. Default: 3
- embed_dim (int): Patch embedding dimension. Default: 96
- depths (tuple(int)): Depth of each Swin Transformer layer.
- num_heads (tuple(int)): Number of attention heads in different layers.
- window_size (int): Window size. Default: 7
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
- drop_rate (float): Dropout rate. Default: 0
- attn_drop_rate (float): Attention dropout rate. Default: 0
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
- upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
- img_range: Image range. 1. or 255.
- upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
- resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
- """
-
- def __init__(self, img_size=64, patch_size=1, in_chans=3,
- embed_dim=96, depths=(6, 6, 6, 6), num_heads=(6, 6, 6, 6),
- window_size=7, mlp_ratio=4., qkv_bias=True,
- drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
- norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
- use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
- **kwargs):
- super(Swin2SR, self).__init__()
- num_in_ch = in_chans
- num_out_ch = in_chans
- num_feat = 64
- self.img_range = img_range
- if in_chans == 3:
- rgb_mean = (0.4488, 0.4371, 0.4040)
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
- else:
- self.mean = torch.zeros(1, 1, 1, 1)
- self.upscale = upscale
- self.upsampler = upsampler
- self.window_size = window_size
-
- #####################################################################################################
- ################################### 1, shallow feature extraction ###################################
- self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
-
- #####################################################################################################
- ################################### 2, deep feature extraction ######################################
- self.num_layers = len(depths)
- self.embed_dim = embed_dim
- self.ape = ape
- self.patch_norm = patch_norm
- self.num_features = embed_dim
- self.mlp_ratio = mlp_ratio
-
- # split image into non-overlapping patches
- self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
- num_patches = self.patch_embed.num_patches
- patches_resolution = self.patch_embed.patches_resolution
- self.patches_resolution = patches_resolution
-
- # merge non-overlapping patches into image
- self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
-
- # absolute position embedding
- if self.ape:
- self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
- trunc_normal_(self.absolute_pos_embed, std=.02)
-
- self.pos_drop = nn.Dropout(p=drop_rate)
-
- # stochastic depth
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
-
- # build Residual Swin Transformer blocks (RSTB)
- self.layers = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(dim=embed_dim,
- input_resolution=(patches_resolution[0],
- patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop_rate, attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection
-
- )
- self.layers.append(layer)
-
- if self.upsampler == 'pixelshuffle_hf':
- self.layers_hf = nn.ModuleList()
- for i_layer in range(self.num_layers):
- layer = RSTB(dim=embed_dim,
- input_resolution=(patches_resolution[0],
- patches_resolution[1]),
- depth=depths[i_layer],
- num_heads=num_heads[i_layer],
- window_size=window_size,
- mlp_ratio=self.mlp_ratio,
- qkv_bias=qkv_bias,
- drop=drop_rate, attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
- norm_layer=norm_layer,
- downsample=None,
- use_checkpoint=use_checkpoint,
- img_size=img_size,
- patch_size=patch_size,
- resi_connection=resi_connection
-
- )
- self.layers_hf.append(layer)
-
- self.norm = norm_layer(self.num_features)
-
- # build the last conv layer in deep feature extraction
- if resi_connection == '1conv':
- self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- elif resi_connection == '3conv':
- # to save parameters and memory
- self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
-
- #####################################################################################################
- ################################ 3, high quality image reconstruction ################################
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- elif self.upsampler == 'pixelshuffle_aux':
- self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
- self.conv_before_upsample = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.conv_after_aux = nn.Sequential(
- nn.Conv2d(3, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
- elif self.upsampler == 'pixelshuffle_hf':
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.upsample = Upsample(upscale, num_feat)
- self.upsample_hf = Upsample_hf(upscale, num_feat)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
- self.conv_before_upsample_hf = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
-
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR (to save parameters)
- self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
- (patches_resolution[0], patches_resolution[1]))
- elif self.upsampler == 'nearest+conv':
- # for real-world SR (less artifacts)
- assert self.upscale == 4, 'only support x4 now.'
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
- nn.LeakyReLU(inplace=True))
- self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
- else:
- # for image denoising and JPEG compression artifact reduction
- self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
-
- self.apply(self._init_weights)
-
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- elif isinstance(m, nn.LayerNorm):
- nn.init.constant_(m.bias, 0)
- nn.init.constant_(m.weight, 1.0)
-
- @torch.jit.ignore
- def no_weight_decay(self):
- return {'absolute_pos_embed'}
-
- @torch.jit.ignore
- def no_weight_decay_keywords(self):
- return {'relative_position_bias_table'}
-
- def check_image_size(self, x):
- _, _, h, w = x.size()
- mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
- mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
- x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
- return x
-
- def forward_features(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward_features_hf(self, x):
- x_size = (x.shape[2], x.shape[3])
- x = self.patch_embed(x)
- if self.ape:
- x = x + self.absolute_pos_embed
- x = self.pos_drop(x)
-
- for layer in self.layers_hf:
- x = layer(x, x_size)
-
- x = self.norm(x) # B L C
- x = self.patch_unembed(x, x_size)
-
- return x
-
- def forward(self, x):
- H, W = x.shape[2:]
- x = self.check_image_size(x)
-
- self.mean = self.mean.type_as(x)
- x = (x - self.mean) * self.img_range
-
- if self.upsampler == 'pixelshuffle':
- # for classical SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.conv_last(self.upsample(x))
- elif self.upsampler == 'pixelshuffle_aux':
- bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False)
- bicubic = self.conv_bicubic(bicubic)
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- aux = self.conv_aux(x) # b, 3, LR_H, LR_W
- x = self.conv_after_aux(aux)
- x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale]
- x = self.conv_last(x)
- aux = aux / self.img_range + self.mean
- elif self.upsampler == 'pixelshuffle_hf':
- # for classical SR with HF
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x_before = self.conv_before_upsample(x)
- x_out = self.conv_last(self.upsample(x_before))
-
- x_hf = self.conv_first_hf(x_before)
- x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
- x_hf = self.conv_before_upsample_hf(x_hf)
- x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
- x = x_out + x_hf
- x_hf = x_hf / self.img_range + self.mean
-
- elif self.upsampler == 'pixelshuffledirect':
- # for lightweight SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.upsample(x)
- elif self.upsampler == 'nearest+conv':
- # for real-world SR
- x = self.conv_first(x)
- x = self.conv_after_body(self.forward_features(x)) + x
- x = self.conv_before_upsample(x)
- x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
- x = self.conv_last(self.lrelu(self.conv_hr(x)))
- else:
- # for image denoising and JPEG compression artifact reduction
- x_first = self.conv_first(x)
- res = self.conv_after_body(self.forward_features(x_first)) + x_first
- x = x + self.conv_last(res)
-
- x = x / self.img_range + self.mean
- if self.upsampler == "pixelshuffle_aux":
- return x[:, :, :H*self.upscale, :W*self.upscale], aux
-
- elif self.upsampler == "pixelshuffle_hf":
- x_out = x_out / self.img_range + self.mean
- return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale]
-
- else:
- return x[:, :, :H*self.upscale, :W*self.upscale]
-
- def flops(self):
- flops = 0
- H, W = self.patches_resolution
- flops += H * W * 3 * self.embed_dim * 9
- flops += self.patch_embed.flops()
- for layer in self.layers:
- flops += layer.flops()
- flops += H * W * 3 * self.embed_dim * self.embed_dim
- flops += self.upsample.flops()
- return flops
-
-
-if __name__ == '__main__':
- upscale = 4
- window_size = 8
- height = (1024 // upscale // window_size + 1) * window_size
- width = (720 // upscale // window_size + 1) * window_size
- model = Swin2SR(upscale=2, img_size=(height, width),
- window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
- embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
- print(model)
- print(height, width, model.flops() / 1e9)
-
- x = torch.randn((1, 3, height, width))
- x = model(x)
- print(x.shape)
diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index ac2c3de4..4c10d9c7 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -1,7 +1,7 @@ import math
import gradio as gr
-from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste
+from modules import scripts, shared, ui_components, ui_settings, infotext_utils
from modules.ui_components import FormColumn
@@ -25,7 +25,7 @@ class ExtraOptionsSection(scripts.Script): extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img
elem_id_tabname = "extra_options_" + ("img2img" if is_img2img else "txt2img")
- mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping}
+ mapping = {k: v for v, k in infotext_utils.infotext_to_setting_name_mapping}
with gr.Blocks() as interface:
with gr.Accordion("Options", open=False, elem_id=elem_id_tabname) if shared.opts.extra_options_accordion and extra_options else gr.Group(elem_id=elem_id_tabname):
diff --git a/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py b/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py new file mode 100644 index 00000000..d9024344 --- /dev/null +++ b/extensions-builtin/soft-inpainting/scripts/soft_inpainting.py @@ -0,0 +1,747 @@ +import numpy as np +import gradio as gr +import math +from modules.ui_components import InputAccordion +import modules.scripts as scripts + + +class SoftInpaintingSettings: + def __init__(self, + mask_blend_power, + mask_blend_scale, + inpaint_detail_preservation, + composite_mask_influence, + composite_difference_threshold, + composite_difference_contrast): + self.mask_blend_power = mask_blend_power + self.mask_blend_scale = mask_blend_scale + self.inpaint_detail_preservation = inpaint_detail_preservation + self.composite_mask_influence = composite_mask_influence + self.composite_difference_threshold = composite_difference_threshold + self.composite_difference_contrast = composite_difference_contrast + + def add_generation_params(self, dest): + dest[enabled_gen_param_label] = True + dest[gen_param_labels.mask_blend_power] = self.mask_blend_power + dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale + dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation + dest[gen_param_labels.composite_mask_influence] = self.composite_mask_influence + dest[gen_param_labels.composite_difference_threshold] = self.composite_difference_threshold + dest[gen_param_labels.composite_difference_contrast] = self.composite_difference_contrast + + +# ------------------- Methods ------------------- + +def processing_uses_inpainting(p): + # TODO: Figure out a better way to determine if inpainting is being used by p + if getattr(p, "image_mask", None) is not None: + return True + + if getattr(p, "mask", None) is not None: + return True + + if getattr(p, "nmask", None) is not None: + return True + + return False + + +def latent_blend(settings, a, b, t): + """ + Interpolates two latent image representations according to the parameter t, + where the interpolated vectors' magnitudes are also interpolated separately. + The "detail_preservation" factor biases the magnitude interpolation towards + the larger of the two magnitudes. + """ + import torch + + # NOTE: We use inplace operations wherever possible. + + # [4][w][h] to [1][4][w][h] + t2 = t.unsqueeze(0) + # [4][w][h] to [1][1][w][h] - the [4] seem redundant. + t3 = t[0].unsqueeze(0).unsqueeze(0) + + one_minus_t2 = 1 - t2 + one_minus_t3 = 1 - t3 + + # Linearly interpolate the image vectors. + a_scaled = a * one_minus_t2 + b_scaled = b * t2 + image_interp = a_scaled + image_interp.add_(b_scaled) + result_type = image_interp.dtype + del a_scaled, b_scaled, t2, one_minus_t2 + + # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.) + # 64-bit operations are used here to allow large exponents. + current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001) + + # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1). + a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_( + settings.inpaint_detail_preservation) * one_minus_t3 + b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_( + settings.inpaint_detail_preservation) * t3 + desired_magnitude = a_magnitude + desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation) + del a_magnitude, b_magnitude, t3, one_minus_t3 + + # Change the linearly interpolated image vectors' magnitudes to the value we want. + # This is the last 64-bit operation. + image_interp_scaling_factor = desired_magnitude + image_interp_scaling_factor.div_(current_magnitude) + image_interp_scaling_factor = image_interp_scaling_factor.to(result_type) + image_interp_scaled = image_interp + image_interp_scaled.mul_(image_interp_scaling_factor) + del current_magnitude + del desired_magnitude + del image_interp + del image_interp_scaling_factor + del result_type + + return image_interp_scaled + + +def get_modified_nmask(settings, nmask, sigma): + """ + Converts a negative mask representing the transparency of the original latent vectors being overlayed + to a mask that is scaled according to the denoising strength for this step. + + Where: + 0 = fully opaque, infinite density, fully masked + 1 = fully transparent, zero density, fully unmasked + + We bring this transparency to a power, as this allows one to simulate N number of blending operations + where N can be any positive real value. Using this one can control the balance of influence between + the denoiser and the original latents according to the sigma value. + + NOTE: "mask" is not used + """ + import torch + return torch.pow(nmask, (sigma ** settings.mask_blend_power) * settings.mask_blend_scale) + + +def apply_adaptive_masks( + settings: SoftInpaintingSettings, + nmask, + latent_orig, + latent_processed, + overlay_images, + width, height, + paste_to): + import torch + import modules.processing as proc + import modules.images as images + from PIL import Image, ImageOps, ImageFilter + + # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control. + latent_mask = nmask[0].float() + # convert the original mask into a form we use to scale distances for thresholding + mask_scalar = 1 - (torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2)) + mask_scalar = (0.5 * (1 - settings.composite_mask_influence) + + mask_scalar * settings.composite_mask_influence) + mask_scalar = mask_scalar / (1.00001 - mask_scalar) + mask_scalar = mask_scalar.cpu().numpy() + + latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1) + + kernel, kernel_center = get_gaussian_kernel(stddev_radius=1.5, max_radius=2) + + masks_for_overlay = [] + + for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)): + converted_mask = distance_map.float().cpu().numpy() + converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.9, percentile_max=1, min_width=1) + converted_mask = weighted_histogram_filter(converted_mask, kernel, kernel_center, + percentile_min=0.25, percentile_max=0.75, min_width=1) + + # The distance at which opacity of original decreases to 50% + half_weighted_distance = settings.composite_difference_threshold * mask_scalar + converted_mask = converted_mask / half_weighted_distance + + converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast) + converted_mask = smootherstep(converted_mask) + converted_mask = 1 - converted_mask + converted_mask = 255. * converted_mask + converted_mask = converted_mask.astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, width, height) + converted_mask = proc.create_binary_mask(converted_mask, round=False) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if paste_to is not None: + converted_mask = proc.uncrop(converted_mask, + (overlay_image.width, overlay_image.height), + paste_to) + + masks_for_overlay.append(converted_mask) + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + overlay_images[i] = image_masked.convert('RGBA') + + return masks_for_overlay + + +def apply_masks( + settings, + nmask, + overlay_images, + width, height, + paste_to): + import torch + import modules.processing as proc + import modules.images as images + from PIL import Image, ImageOps, ImageFilter + + converted_mask = nmask[0].float() + converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(settings.mask_blend_scale / 2) + converted_mask = 255. * converted_mask + converted_mask = converted_mask.cpu().numpy().astype(np.uint8) + converted_mask = Image.fromarray(converted_mask) + converted_mask = images.resize_image(2, converted_mask, width, height) + converted_mask = proc.create_binary_mask(converted_mask, round=False) + + # Remove aliasing artifacts using a gaussian blur. + converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4)) + + # Expand the mask to fit the whole image if needed. + if paste_to is not None: + converted_mask = proc.uncrop(converted_mask, + (width, height), + paste_to) + + masks_for_overlay = [] + + for i, overlay_image in enumerate(overlay_images): + masks_for_overlay[i] = converted_mask + + image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height)) + image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(converted_mask.convert('L'))) + + overlay_images[i] = image_masked.convert('RGBA') + + return masks_for_overlay + + +def weighted_histogram_filter(img, kernel, kernel_center, percentile_min=0.0, percentile_max=1.0, min_width=1.0): + """ + Generalization convolution filter capable of applying + weighted mean, median, maximum, and minimum filters + parametrically using an arbitrary kernel. + + Args: + img (nparray): + The image, a 2-D array of floats, to which the filter is being applied. + kernel (nparray): + The kernel, a 2-D array of floats. + kernel_center (nparray): + The kernel center coordinate, a 1-D array with two elements. + percentile_min (float): + The lower bound of the histogram window used by the filter, + from 0 to 1. + percentile_max (float): + The upper bound of the histogram window used by the filter, + from 0 to 1. + min_width (float): + The minimum size of the histogram window bounds, in weight units. + Must be greater than 0. + + Returns: + (nparray): A filtered copy of the input image "img", a 2-D array of floats. + """ + + # Converts an index tuple into a vector. + def vec(x): + return np.array(x) + + kernel_min = -kernel_center + kernel_max = vec(kernel.shape) - kernel_center + + def weighted_histogram_filter_single(idx): + idx = vec(idx) + min_index = np.maximum(0, idx + kernel_min) + max_index = np.minimum(vec(img.shape), idx + kernel_max) + window_shape = max_index - min_index + + class WeightedElement: + """ + An element of the histogram, its weight + and bounds. + """ + + def __init__(self, value, weight): + self.value: float = value + self.weight: float = weight + self.window_min: float = 0.0 + self.window_max: float = 1.0 + + # Collect the values in the image as WeightedElements, + # weighted by their corresponding kernel values. + values = [] + for window_tup in np.ndindex(tuple(window_shape)): + window_index = vec(window_tup) + image_index = window_index + min_index + centered_kernel_index = image_index - idx + kernel_index = centered_kernel_index + kernel_center + element = WeightedElement(img[tuple(image_index)], kernel[tuple(kernel_index)]) + values.append(element) + + def sort_key(x: WeightedElement): + return x.value + + values.sort(key=sort_key) + + # Calculate the height of the stack (sum) + # and each sample's range they occupy in the stack + sum = 0 + for i in range(len(values)): + values[i].window_min = sum + sum += values[i].weight + values[i].window_max = sum + + # Calculate what range of this stack ("window") + # we want to get the weighted average across. + window_min = sum * percentile_min + window_max = sum * percentile_max + window_width = window_max - window_min + + # Ensure the window is within the stack and at least a certain size. + if window_width < min_width: + window_center = (window_min + window_max) / 2 + window_min = window_center - min_width / 2 + window_max = window_center + min_width / 2 + + if window_max > sum: + window_max = sum + window_min = sum - min_width + + if window_min < 0: + window_min = 0 + window_max = min_width + + value = 0 + value_weight = 0 + + # Get the weighted average of all the samples + # that overlap with the window, weighted + # by the size of their overlap. + for i in range(len(values)): + if window_min >= values[i].window_max: + continue + if window_max <= values[i].window_min: + break + + s = max(window_min, values[i].window_min) + e = min(window_max, values[i].window_max) + w = e - s + + value += values[i].value * w + value_weight += w + + return value / value_weight if value_weight != 0 else 0 + + img_out = img.copy() + + # Apply the kernel operation over each pixel. + for index in np.ndindex(img.shape): + img_out[index] = weighted_histogram_filter_single(index) + + return img_out + + +def smoothstep(x): + """ + The smoothstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * (3 - 2 * x) + + +def smootherstep(x): + """ + The smootherstep function, input should be clamped to 0-1 range. + Turns a diagonal line (f(x) = x) into a sigmoid-like curve. + """ + return x * x * x * (x * (6 * x - 15) + 10) + + +def get_gaussian_kernel(stddev_radius=1.0, max_radius=2): + """ + Creates a Gaussian kernel with thresholded edges. + + Args: + stddev_radius (float): + Standard deviation of the gaussian kernel, in pixels. + max_radius (int): + The size of the filter kernel. The number of pixels is (max_radius*2+1) ** 2. + The kernel is thresholded so that any values one pixel beyond this radius + is weighted at 0. + + Returns: + (nparray, nparray): A kernel array (shape: (N, N)), its center coordinate (shape: (2)) + """ + + # Evaluates a 0-1 normalized gaussian function for a given square distance from the mean. + def gaussian(sqr_mag): + return math.exp(-sqr_mag / (stddev_radius * stddev_radius)) + + # Helper function for converting a tuple to an array. + def vec(x): + return np.array(x) + + """ + Since a gaussian is unbounded, we need to limit ourselves + to a finite range. + We taper the ends off at the end of that range so they equal zero + while preserving the maximum value of 1 at the mean. + """ + zero_radius = max_radius + 1.0 + gauss_zero = gaussian(zero_radius * zero_radius) + gauss_kernel_scale = 1 / (1 - gauss_zero) + + def gaussian_kernel_func(coordinate): + x = coordinate[0] ** 2.0 + coordinate[1] ** 2.0 + x = gaussian(x) + x -= gauss_zero + x *= gauss_kernel_scale + x = max(0.0, x) + return x + + size = max_radius * 2 + 1 + kernel_center = max_radius + kernel = np.zeros((size, size)) + + for index in np.ndindex(kernel.shape): + kernel[index] = gaussian_kernel_func(vec(index) - kernel_center) + + return kernel, kernel_center + + +# ------------------- Constants ------------------- + + +default = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2) + +enabled_ui_label = "Soft inpainting" +enabled_gen_param_label = "Soft inpainting enabled" +enabled_el_id = "soft_inpainting_enabled" + +ui_labels = SoftInpaintingSettings( + "Schedule bias", + "Preservation strength", + "Transition contrast boost", + "Mask influence", + "Difference threshold", + "Difference contrast") + +ui_info = SoftInpaintingSettings( + "Shifts when preservation of original content occurs during denoising.", + "How strongly partially masked content should be preserved.", + "Amplifies the contrast that may be lost in partially masked regions.", + "How strongly the original mask should bias the difference threshold.", + "How much an image region can change before the original pixels are not blended in anymore.", + "How sharp the transition should be between blended and not blended.") + +gen_param_labels = SoftInpaintingSettings( + "Soft inpainting schedule bias", + "Soft inpainting preservation strength", + "Soft inpainting transition contrast boost", + "Soft inpainting mask influence", + "Soft inpainting difference threshold", + "Soft inpainting difference contrast") + +el_ids = SoftInpaintingSettings( + "mask_blend_power", + "mask_blend_scale", + "inpaint_detail_preservation", + "composite_mask_influence", + "composite_difference_threshold", + "composite_difference_contrast") + + +# ------------------- Script ------------------- + + +class Script(scripts.Script): + def __init__(self): + self.section = "inpaint" + self.masks_for_overlay = None + self.overlay_images = None + + def title(self): + return "Soft Inpainting" + + def show(self, is_img2img): + return scripts.AlwaysVisible if is_img2img else False + + def ui(self, is_img2img): + if not is_img2img: + return + + with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled: + with gr.Group(): + gr.Markdown( + """ + Soft inpainting allows you to **seamlessly blend original content with inpainted content** according to the mask opacity. + **High _Mask blur_** values are recommended! + """) + + power = \ + gr.Slider(label=ui_labels.mask_blend_power, + info=ui_info.mask_blend_power, + minimum=0, + maximum=8, + step=0.1, + value=default.mask_blend_power, + elem_id=el_ids.mask_blend_power) + scale = \ + gr.Slider(label=ui_labels.mask_blend_scale, + info=ui_info.mask_blend_scale, + minimum=0, + maximum=8, + step=0.05, + value=default.mask_blend_scale, + elem_id=el_ids.mask_blend_scale) + detail = \ + gr.Slider(label=ui_labels.inpaint_detail_preservation, + info=ui_info.inpaint_detail_preservation, + minimum=1, + maximum=32, + step=0.5, + value=default.inpaint_detail_preservation, + elem_id=el_ids.inpaint_detail_preservation) + + gr.Markdown( + """ + ### Pixel Composite Settings + """) + + mask_inf = \ + gr.Slider(label=ui_labels.composite_mask_influence, + info=ui_info.composite_mask_influence, + minimum=0, + maximum=1, + step=0.05, + value=default.composite_mask_influence, + elem_id=el_ids.composite_mask_influence) + + dif_thresh = \ + gr.Slider(label=ui_labels.composite_difference_threshold, + info=ui_info.composite_difference_threshold, + minimum=0, + maximum=8, + step=0.25, + value=default.composite_difference_threshold, + elem_id=el_ids.composite_difference_threshold) + + dif_contr = \ + gr.Slider(label=ui_labels.composite_difference_contrast, + info=ui_info.composite_difference_contrast, + minimum=0, + maximum=8, + step=0.25, + value=default.composite_difference_contrast, + elem_id=el_ids.composite_difference_contrast) + + with gr.Accordion("Help", open=False): + gr.Markdown( + f""" + ### {ui_labels.mask_blend_power} + + The blending strength of original content is scaled proportionally with the decreasing noise level values at each step (sigmas). + This ensures that the influence of the denoiser and original content preservation is roughly balanced at each step. + This balance can be shifted using this parameter, controlling whether earlier or later steps have stronger preservation. + + - **Below 1**: Stronger preservation near the end (with low sigma) + - **1**: Balanced (proportional to sigma) + - **Above 1**: Stronger preservation in the beginning (with high sigma) + """) + gr.Markdown( + f""" + ### {ui_labels.mask_blend_scale} + + Skews whether partially masked image regions should be more likely to preserve the original content or favor inpainted content. + This may need to be adjusted depending on the {ui_labels.mask_blend_power}, CFG Scale, prompt and Denoising strength. + + - **Low values**: Favors generated content. + - **High values**: Favors original content. + """) + gr.Markdown( + f""" + ### {ui_labels.inpaint_detail_preservation} + + This parameter controls how the original latent vectors and denoised latent vectors are interpolated. + With higher values, the magnitude of the resulting blended vector will be closer to the maximum of the two interpolated vectors. + This can prevent the loss of contrast that occurs with linear interpolation. + + - **Low values**: Softer blending, details may fade. + - **High values**: Stronger contrast, may over-saturate colors. + """) + + gr.Markdown( + """ + ## Pixel Composite Settings + + Masks are generated based on how much a part of the image changed after denoising. + These masks are used to blend the original and final images together. + If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_mask_influence} + + This parameter controls how much the mask should bias this sensitivity to difference. + + - **0**: Ignore the mask, only consider differences in image content. + - **1**: Follow the mask closely despite image content changes. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_difference_threshold} + + This value represents the difference at which the original pixels will have less than 50% opacity. + + - **Low values**: Two images patches must be almost the same in order to retain original pixels. + - **High values**: Two images patches can be very different and still retain original pixels. + """) + + gr.Markdown( + f""" + ### {ui_labels.composite_difference_contrast} + + This value represents the contrast between the opacity of the original and inpainted content. + + - **Low values**: The blend will be more gradual and have longer transitions, but may cause ghosting. + - **High values**: Ghosting will be less common, but transitions may be very sudden. + """) + + self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label), + (power, gen_param_labels.mask_blend_power), + (scale, gen_param_labels.mask_blend_scale), + (detail, gen_param_labels.inpaint_detail_preservation), + (mask_inf, gen_param_labels.composite_mask_influence), + (dif_thresh, gen_param_labels.composite_difference_threshold), + (dif_contr, gen_param_labels.composite_difference_contrast)] + + self.paste_field_names = [] + for _, field_name in self.infotext_fields: + self.paste_field_names.append(field_name) + + return [soft_inpainting_enabled, + power, + scale, + detail, + mask_inf, + dif_thresh, + dif_contr] + + def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr): + if not enabled: + return + + if not processing_uses_inpainting(p): + return + + # Shut off the rounding it normally does. + p.mask_round = False + + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) + + # p.extra_generation_params["Mask rounding"] = False + settings.add_generation_params(p.extra_generation_params) + + def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf, + dif_thresh, dif_contr): + if not enabled: + return + + if not processing_uses_inpainting(p): + return + + if mba.is_final_blend: + mba.blended_latent = mba.current_latent + return + + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) + + # todo: Why is sigma 2D? Both values are the same. + mba.blended_latent = latent_blend(settings, + mba.init_latent, + mba.current_latent, + get_modified_nmask(settings, mba.nmask, mba.sigma[0])) + + def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf, + dif_thresh, dif_contr): + if not enabled: + return + + if not processing_uses_inpainting(p): + return + + nmask = getattr(p, "nmask", None) + if nmask is None: + return + + from modules import images + from modules.shared import opts + + settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr) + + # since the original code puts holes in the existing overlay images, + # we have to rebuild them. + self.overlay_images = [] + for img in p.init_images: + + image = images.flatten(img, opts.img2img_background_color) + + if p.paste_to is None and p.resize_mode != 3: + image = images.resize_image(p.resize_mode, image, p.width, p.height) + + self.overlay_images.append(image.convert('RGBA')) + + if len(p.init_images) == 1: + self.overlay_images = self.overlay_images * p.batch_size + + if getattr(ps.samples, 'already_decoded', False): + self.masks_for_overlay = apply_masks(settings=settings, + nmask=nmask, + overlay_images=self.overlay_images, + width=p.width, + height=p.height, + paste_to=p.paste_to) + else: + self.masks_for_overlay = apply_adaptive_masks(settings=settings, + nmask=nmask, + latent_orig=p.init_latent, + latent_processed=ps.samples, + overlay_images=self.overlay_images, + width=p.width, + height=p.height, + paste_to=p.paste_to) + + def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, + detail_preservation, mask_inf, dif_thresh, dif_contr): + if not enabled: + return + + if not processing_uses_inpainting(p): + return + + if self.masks_for_overlay is None: + return + + if self.overlay_images is None: + return + + ppmo.mask_for_overlay = self.masks_for_overlay[ppmo.index] + ppmo.overlay_image = self.overlay_images[ppmo.index] |