aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-12-31 06:41:49 +0000
committerGitHub <noreply@github.com>2023-12-31 06:41:49 +0000
commita84e842189f5599fd354147f72d1a9b9ed0716c8 (patch)
treeae0e5e9df369eb1cefa41ee76eb0e56fe945d192
parentce21840a042b9454a136372ab2971c1f21ec51e0 (diff)
parent6f86b62a1be7993073ba3a789d522e0b8870605a (diff)
downloadstable-diffusion-webui-gfx803-a84e842189f5599fd354147f72d1a9b9ed0716c8.tar.gz
stable-diffusion-webui-gfx803-a84e842189f5599fd354147f72d1a9b9ed0716c8.tar.bz2
stable-diffusion-webui-gfx803-a84e842189f5599fd354147f72d1a9b9ed0716c8.zip
Merge pull request #14476 from akx/dedupe-tiled-weighted-inference
Deduplicate tiled inference code from SwinIR/ScuNET
-rw-r--r--extensions-builtin/ScuNET/scripts/scunet_model.py55
-rw-r--r--extensions-builtin/SwinIR/scripts/swinir_model.py57
-rw-r--r--modules/upscaler_utils.py72
3 files changed, 87 insertions, 97 deletions
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py
index 5f3dd08b..f799cb76 100644
--- a/extensions-builtin/ScuNET/scripts/scunet_model.py
+++ b/extensions-builtin/ScuNET/scripts/scunet_model.py
@@ -3,12 +3,11 @@ import sys
import PIL.Image
import numpy as np
import torch
-from tqdm import tqdm
import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors
-
from modules.shared import opts
+from modules.upscaler_utils import tiled_upscale_2
class UpscalerScuNET(modules.upscaler.Upscaler):
@@ -40,47 +39,6 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scalers.append(scaler_data2)
self.scalers = scalers
- @staticmethod
- @torch.no_grad()
- def tiled_inference(img, model):
- # test the image tile by tile
- h, w = img.shape[2:]
- tile = opts.SCUNET_tile
- tile_overlap = opts.SCUNET_tile_overlap
- if tile == 0:
- return model(img)
-
- device = devices.get_device_for('scunet')
- assert tile % 8 == 0, "tile size should be a multiple of window_size"
- sf = 1
-
- stride = tile - tile_overlap
- h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
- w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
- E = torch.zeros(1, 3, h * sf, w * sf, dtype=img.dtype, device=device)
- W = torch.zeros_like(E, dtype=devices.dtype, device=device)
-
- with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="ScuNET tiles") as pbar:
- for h_idx in h_idx_list:
-
- for w_idx in w_idx_list:
-
- in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
-
- out_patch = model(in_patch)
- out_patch_mask = torch.ones_like(out_patch)
-
- E[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch)
- W[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch_mask)
- pbar.update(1)
- output = E.div_(W)
-
- return output
-
def do_upscale(self, img: PIL.Image.Image, selected_file):
devices.torch_gc()
@@ -104,7 +62,16 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
_img[:, :, :h, :w] = torch_img # pad image
torch_img = _img
- torch_output = self.tiled_inference(torch_img, model).squeeze(0)
+ with torch.no_grad():
+ torch_output = tiled_upscale_2(
+ torch_img,
+ model,
+ tile_size=opts.SCUNET_tile,
+ tile_overlap=opts.SCUNET_tile_overlap,
+ scale=1,
+ device=devices.get_device_for('scunet'),
+ desc="ScuNET tiles",
+ ).squeeze(0)
torch_output = torch_output[:, :h * 1, :w * 1] # remove padding, if any
np_output: np.ndarray = torch_output.float().cpu().clamp_(0, 1).numpy()
del torch_img, torch_output
diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py
index 95c7ec64..8a555c79 100644
--- a/extensions-builtin/SwinIR/scripts/swinir_model.py
+++ b/extensions-builtin/SwinIR/scripts/swinir_model.py
@@ -4,11 +4,11 @@ import sys
import numpy as np
import torch
from PIL import Image
-from tqdm import tqdm
from modules import modelloader, devices, script_callbacks, shared
-from modules.shared import opts, state
+from modules.shared import opts
from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import tiled_upscale_2
SWINIR_MODEL_URL = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
@@ -110,14 +110,14 @@ def upscale(
w_pad = (w_old // window_size + 1) * window_size - w_old
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
- output = inference(
+ output = tiled_upscale_2(
img,
model,
- tile=tile,
+ tile_size=tile,
tile_overlap=tile_overlap,
- window_size=window_size,
scale=scale,
device=device,
+ desc="SwinIR tiles",
)
output = output[..., : h_old * scale, : w_old * scale]
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
@@ -129,53 +129,6 @@ def upscale(
return Image.fromarray(output, "RGB")
-def inference(
- img,
- model,
- *,
- tile: int,
- tile_overlap: int,
- window_size: int,
- scale: int,
- device,
-):
- # test the image tile by tile
- b, c, h, w = img.size()
- tile = min(tile, h, w)
- assert tile % window_size == 0, "tile size should be a multiple of window_size"
- sf = scale
-
- stride = tile - tile_overlap
- h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
- w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
- E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device).type_as(img)
- W = torch.zeros_like(E, dtype=devices.dtype, device=device)
-
- with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
- for h_idx in h_idx_list:
- if state.interrupted or state.skipped:
- break
-
- for w_idx in w_idx_list:
- if state.interrupted or state.skipped:
- break
-
- in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
- out_patch = model(in_patch)
- out_patch_mask = torch.ones_like(out_patch)
-
- E[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch)
- W[
- ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
- ].add_(out_patch_mask)
- pbar.update(1)
- output = E.div_(W)
-
- return output
-
-
def on_ui_settings():
import gradio as gr
diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py
index 174c9bc3..8e413854 100644
--- a/modules/upscaler_utils.py
+++ b/modules/upscaler_utils.py
@@ -6,7 +6,7 @@ import torch
import tqdm
from PIL import Image
-from modules import images
+from modules import images, shared
logger = logging.getLogger(__name__)
@@ -68,3 +68,73 @@ def upscale_with_model(
overlap=grid.overlap * scale_factor,
)
return images.combine_grid(newgrid)
+
+
+def tiled_upscale_2(
+ img,
+ model,
+ *,
+ tile_size: int,
+ tile_overlap: int,
+ scale: int,
+ device,
+ desc="Tiled upscale",
+):
+ # Alternative implementation of `upscale_with_model` originally used by
+ # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and
+ # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in
+ # Pillow space without weighting.
+ b, c, h, w = img.size()
+ tile_size = min(tile_size, h, w)
+
+ if tile_size <= 0:
+ logger.debug("Upscaling %s without tiling", img.shape)
+ return model(img)
+
+ stride = tile_size - tile_overlap
+ h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size]
+ w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size]
+ result = torch.zeros(
+ b,
+ c,
+ h * scale,
+ w * scale,
+ device=device,
+ ).type_as(img)
+ weights = torch.zeros_like(result)
+ logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape)
+ with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar:
+ for h_idx in h_idx_list:
+ if shared.state.interrupted or shared.state.skipped:
+ break
+
+ for w_idx in w_idx_list:
+ if shared.state.interrupted or shared.state.skipped:
+ break
+
+ in_patch = img[
+ ...,
+ h_idx : h_idx + tile_size,
+ w_idx : w_idx + tile_size,
+ ]
+ out_patch = model(in_patch)
+
+ result[
+ ...,
+ h_idx * scale : (h_idx + tile_size) * scale,
+ w_idx * scale : (w_idx + tile_size) * scale,
+ ].add_(out_patch)
+
+ out_patch_mask = torch.ones_like(out_patch)
+
+ weights[
+ ...,
+ h_idx * scale : (h_idx + tile_size) * scale,
+ w_idx * scale : (w_idx + tile_size) * scale,
+ ].add_(out_patch_mask)
+
+ pbar.update(1)
+
+ output = result.div_(weights)
+
+ return output