diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-12-31 06:41:49 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-31 06:41:49 +0000 |
commit | a84e842189f5599fd354147f72d1a9b9ed0716c8 (patch) | |
tree | ae0e5e9df369eb1cefa41ee76eb0e56fe945d192 /extensions-builtin | |
parent | ce21840a042b9454a136372ab2971c1f21ec51e0 (diff) | |
parent | 6f86b62a1be7993073ba3a789d522e0b8870605a (diff) | |
download | stable-diffusion-webui-gfx803-a84e842189f5599fd354147f72d1a9b9ed0716c8.tar.gz stable-diffusion-webui-gfx803-a84e842189f5599fd354147f72d1a9b9ed0716c8.tar.bz2 stable-diffusion-webui-gfx803-a84e842189f5599fd354147f72d1a9b9ed0716c8.zip |
Merge pull request #14476 from akx/dedupe-tiled-weighted-inference
Deduplicate tiled inference code from SwinIR/ScuNET
Diffstat (limited to 'extensions-builtin')
-rw-r--r-- | extensions-builtin/ScuNET/scripts/scunet_model.py | 55 | ||||
-rw-r--r-- | extensions-builtin/SwinIR/scripts/swinir_model.py | 57 |
2 files changed, 16 insertions, 96 deletions
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 5f3dd08b..f799cb76 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -3,12 +3,11 @@ import sys import PIL.Image import numpy as np import torch -from tqdm import tqdm import modules.upscaler from modules import devices, modelloader, script_callbacks, errors - from modules.shared import opts +from modules.upscaler_utils import tiled_upscale_2 class UpscalerScuNET(modules.upscaler.Upscaler): @@ -40,47 +39,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler): scalers.append(scaler_data2) self.scalers = scalers - @staticmethod - @torch.no_grad() - def tiled_inference(img, model): - # test the image tile by tile - h, w = img.shape[2:] - tile = opts.SCUNET_tile - tile_overlap = opts.SCUNET_tile_overlap - if tile == 0: - return model(img) - - device = devices.get_device_for('scunet') - assert tile % 8 == 0, "tile size should be a multiple of window_size" - sf = 1 - - stride = tile - tile_overlap - h_idx_list = list(range(0, h - tile, stride)) + [h - tile] - w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device) - W = torch.zeros_like(E, dtype=devices.dtype, device=device) - - with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar: - for h_idx in h_idx_list: - - for w_idx in w_idx_list: - - in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] - - out_patch = model(in_patch) - out_patch_mask = torch.ones_like(out_patch) - - E[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch) - W[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch_mask) - pbar.update(1) - output = E.div_(W) - - return output - def do_upscale(self, img: PIL.Image.Image, selected_file): devices.torch_gc() @@ -104,7 +62,16 @@ class UpscalerScuNET(modules.upscaler.Upscaler): _img[:, :, :h, :w] = torch_img # pad image torch_img = _img - torch_output = self.tiled_inference(torch_img, model).squeeze(0) + with torch.no_grad(): + torch_output = tiled_upscale_2( + torch_img, + model, + tile_size=opts.SCUNET_tile, + tile_overlap=opts.SCUNET_tile_overlap, + scale=1, + device=devices.get_device_for('scunet'), + desc="ScuNET tiles", + ).squeeze(0) torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy() del torch_img, torch_output diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 95c7ec64..8a555c79 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -4,11 +4,11 @@ import sys import numpy as np import torch from PIL import Image -from tqdm import tqdm from modules import modelloader, devices, script_callbacks, shared -from modules.shared import opts, state +from modules.shared import opts from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import tiled_upscale_2 SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" @@ -110,14 +110,14 @@ def upscale( w_pad = (w_old // window_size + 1) * window_size - w_old img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] - output = inference( + output = tiled_upscale_2( img, model, - tile=tile, + tile_size=tile, tile_overlap=tile_overlap, - window_size=window_size, scale=scale, device=device, + desc="SwinIR tiles", ) output = output[..., : h_old * scale, : w_old * scale] output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() @@ -129,53 +129,6 @@ def upscale( return Image.fromarray(output, "RGB") -def inference( - img, - model, - *, - tile: int, - tile_overlap: int, - window_size: int, - scale: int, - device, -): - # test the image tile by tile - b, c, h, w = img.size() - tile = min(tile, h, w) - assert tile % window_size == 0, "tile size should be a multiple of window_size" - sf = scale - - stride = tile - tile_overlap - h_idx_list = list(range(0, h - tile, stride)) + [h - tile] - w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img) - W = torch.zeros_like(E, dtype=devices.dtype, device=device) - - with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: - for h_idx in h_idx_list: - if state.interrupted or state.skipped: - break - - for w_idx in w_idx_list: - if state.interrupted or state.skipped: - break - - in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] - out_patch = model(in_patch) - out_patch_mask = torch.ones_like(out_patch) - - E[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch) - W[ - ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf - ].add_(out_patch_mask) - pbar.update(1) - output = E.div_(W) - - return output - - def on_ui_settings(): import gradio as gr |