From b0f59342346b1c8b405f97c0e0bb01c6ae05c601 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 25 Dec 2023 14:43:51 +0200 Subject: Use Spandrel for upscaling and face restoration architectures (aside from GFPGAN and LDSR) --- extensions-builtin/ScuNET/scripts/scunet_model.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) (limited to 'extensions-builtin/ScuNET/scripts/scunet_model.py') diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 167d2f64..18cf8e1a 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -7,9 +7,7 @@ from tqdm import tqdm import modules.upscaler from modules import devices, modelloader, script_callbacks, errors -from scunet_model_arch import SCUNet -from modules.modelloader import load_file_from_url from modules.shared import opts @@ -120,17 +118,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler): device = devices.get_device_for('scunet') if path.startswith("http"): # TODO: this doesn't use `path` at all? - filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") + filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path - model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) - model.load_state_dict(torch.load(filename), strict=True) - model.eval() - for _, v in model.named_parameters(): - v.requires_grad = False - model = model.to(device) - - return model + return modelloader.load_spandrel_model(filename, device=device) def on_ui_settings(): -- cgit v1.2.3 From 4ad0c0c0a805da4bac03cff86ea17c25a1291546 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 16:37:03 +0200 Subject: Verify architecture for loaded Spandrel models --- extensions-builtin/ScuNET/scripts/scunet_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'extensions-builtin/ScuNET/scripts/scunet_model.py') diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 18cf8e1a..5f3dd08b 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -121,7 +121,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path - return modelloader.load_spandrel_model(filename, device=device) + return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet') def on_ui_settings(): -- cgit v1.2.3 From 6f86b62a1be7993073ba3a789d522e0b8870605a Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 22:53:49 +0200 Subject: Deduplicate tiled inference code from SwinIR/ScuNET --- extensions-builtin/ScuNET/scripts/scunet_model.py | 55 +++++------------------ 1 file changed, 11 insertions(+), 44 deletions(-) (limited to 'extensions-builtin/ScuNET/scripts/scunet_model.py') diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 5f3dd08b..f799cb76 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -3,12 +3,11 @@ import sys import PIL.Image import numpy as np import torch -from tqdm import tqdm import modules.upscaler from modules import devices, modelloader, script_callbacks, errors - from modules.shared import opts +from modules.upscaler_utils import tiled_upscale_2 class UpscalerScuNET(modules.upscaler.Upscaler): @@ -40,47 +39,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler): scalers.append(scaler_data2) self.scalers = scalers - @staticmethod - @torch.no_grad() - def tiled_inference(img, model): - # test the image tile by tile - h, w = img.shape[2:] - tile = opts.SCUNET_tile - tile_overlap = opts.SCUNET_tile_overlap - if tile == 0: - return model(img) - - device = devices.get_device_for('scunet') - assert tile % 8 == 0, "tile size should be a multiple of window_size" - sf = 1 - - stride = tile - tile_overlap - h_idx_list = list(range(0, h - tile, stride)) + [h - tile] - w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device) - W = torch.zeros_like(E, dtype=devices.dtype, device=device) - - with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar: - for h_idx in h_idx_list: - - for w_idx in w_idx_list: - - in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] - - out_patch = model(in_patch) - out_patch_mask = torch.ones_like(out_patch) - - E[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch) - W[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch_mask) - pbar.update(1) - output = E.div_(W) - - return output - def do_upscale(self, img: PIL.Image.Image, selected_file): devices.torch_gc() @@ -104,7 +62,16 @@ class UpscalerScuNET(modules.upscaler.Upscaler): _img[:, :, :h, :w] = torch_img # pad image torch_img = _img - torch_output = self.tiled_inference(torch_img, model).squeeze(0) + with torch.no_grad(): + torch_output = tiled_upscale_2( + torch_img, + model, + tile_size=opts.SCUNET_tile, + tile_overlap=opts.SCUNET_tile_overlap, + scale=1, + device=devices.get_device_for('scunet'), + desc="ScuNET tiles", + ).squeeze(0) torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() del torch_img, torch_output -- cgit v1.2.3 From cf14a6a7aaf8ccb40552990785d5c9e400d93610 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 16:11:18 +0200 Subject: Refactor upscale_2 helper out of ScuNET/SwinIR; make sure devices are right --- extensions-builtin/ScuNET/scripts/scunet_model.py | 48 +++++------------------ 1 file changed, 10 insertions(+), 38 deletions(-) (limited to 'extensions-builtin/ScuNET/scripts/scunet_model.py') diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index f799cb76..fe5e5a19 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -1,13 +1,9 @@ import sys import PIL.Image -import numpy as np -import torch import modules.upscaler -from modules import devices, modelloader, script_callbacks, errors -from modules.shared import opts -from modules.upscaler_utils import tiled_upscale_2 +from modules import devices, errors, modelloader, script_callbacks, shared, upscaler_utils class UpscalerScuNET(modules.upscaler.Upscaler): @@ -40,46 +36,23 @@ class UpscalerScuNET(modules.upscaler.Upscaler): self.scalers = scalers def do_upscale(self, img: PIL.Image.Image, selected_file): - devices.torch_gc() - try: model = self.load_model(selected_file) except Exception as e: print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr) return img - device = devices.get_device_for('scunet') - tile = opts.SCUNET_tile - h, w = img.height, img.width - np_img = np.array(img) - np_img = np_img[:, :, ::-1] # RGB to BGR - np_img = np_img.transpose((2, 0, 1)) / 255 # HWC to CHW - torch_img = torch.from_numpy(np_img).float().unsqueeze(0).to(device) # type: ignore - - if tile > h or tile > w: - _img = torch.zeros(1, 3, max(h, tile), max(w, tile), dtype=torch_img.dtype, device=torch_img.device) - _img[:, :, :h, :w] = torch_img # pad image - torch_img = _img - - with torch.no_grad(): - torch_output = tiled_upscale_2( - torch_img, - model, - tile_size=opts.SCUNET_tile, - tile_overlap=opts.SCUNET_tile_overlap, - scale=1, - device=devices.get_device_for('scunet'), - desc="ScuNET tiles", - ).squeeze(0) - torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any - np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() - del torch_img, torch_output + img = upscaler_utils.upscale_2( + img, + model, + tile_size=shared.opts.SCUNET_tile, + tile_overlap=shared.opts.SCUNET_tile_overlap, + scale=1, # ScuNET is a denoising model, not an upscaler + desc='ScuNET', + ) devices.torch_gc() - - output = np_output.transpose((1, 2, 0)) # CHW to HWC - output = output[:, :, ::-1] # BGR to RGB - return PIL.Image.fromarray((output * 255).astype(np.uint8)) + return img def load_model(self, path: str): device = devices.get_device_for('scunet') @@ -93,7 +66,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler): def on_ui_settings(): import gradio as gr - from modules import shared shared.opts.add_option("SCUNET_tile", shared.OptionInfo(256, "Tile size for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")).info("0 = no tiling")) shared.opts.add_option("SCUNET_tile_overlap", shared.OptionInfo(8, "Tile overlap for SCUNET upscalers.", gr.Slider, {"minimum": 0, "maximum": 64, "step": 1}, section=('upscaling', "Upscaling")).info("Low values = visible seam")) -- cgit v1.2.3