aboutsummaryrefslogtreecommitdiffstats
path: root/modules/upscaler_utils.py
diff options
context:
space:
mode:
author: Aarni Koskela <akx@iki.fi> 2023-12-27 09:04:33 +0000
committer: Aarni Koskela <akx@iki.fi> 2023-12-30 14:24:01 +0000
commit: e472383acbb9e07dca311abe5fb16ee2675e410a (patch)
tree: 69591965d87134116235daa785d31f60b70791b4 /modules/upscaler_utils.py
parent: 12c6f37f8e4b1d1d643c9d8d5dfc763c3203c728 (diff)
download: stable-diffusion-webui-gfx803-e472383acbb9e07dca311abe5fb16ee2675e410a.tar.gz
stable-diffusion-webui-gfx803-e472383acbb9e07dca311abe5fb16ee2675e410a.tar.bz2
stable-diffusion-webui-gfx803-e472383acbb9e07dca311abe5fb16ee2675e410a.zip
Refactor esrgan_upscale to more generic upscale_with_model
Diffstat (limited to 'modules/upscaler_utils.py')
-rw-r--r-- modules/upscaler_utils.py 66
1 files changed, 66 insertions, 0 deletions
diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py
new file mode 100644
index 00000000..8bdda51c
--- /dev/null
+++ b/modules/upscaler_utils.py
@@ -0,0 +1,66 @@
+import logging
+from typing import Callable
+
+import numpy as np
+import torch
+import tqdm
+from PIL import Image
+
+from modules import devices, images
+
+# Module-level logger, named after this module via __name__.
+logger = logging.getLogger(__name__)
+
+
+def upscale_without_tiling(model, img: Image.Image):
+ img = np.array(img)
+ img = img[:, :, ::-1]
+ img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
+ img = torch.from_numpy(img).float()
+ img = img.unsqueeze(0).to(devices.device_esrgan)
+ with torch.no_grad():
+ output = model(img)
+ output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output = 255. * np.moveaxis(output, 0, 2)
+ output = output.astype(np.uint8)
+ output = output[:, :, ::-1]
+ return Image.fromarray(output, 'RGB')
+
+
+def upscale_with_model(
+ model: Callable[[torch.Tensor], torch.Tensor],
+ img: Image.Image,
+ *,
+ tile_size: int,
+ tile_overlap: int = 0,
+ desc="tiled upscale",
+) -> Image.Image:
+ if tile_size <= 0:
+ logger.debug("Upscaling %s without tiling", img)
+ output = upscale_without_tiling(model, img)
+ logger.debug("=> %s", output)
+ return output
+
+ grid = images.split_grid(img, tile_size, tile_size, tile_overlap)
+ newtiles = []
+
+ with tqdm.tqdm(total=grid.tile_count, desc=desc) as p:
+ for y, h, row in grid.tiles:
+ newrow = []
+ for x, w, tile in row:
+ logger.debug("Tile (%d, %d) %s...", x, y, tile)
+ output = upscale_without_tiling(model, tile)
+ scale_factor = output.width // tile.width
+ logger.debug("=> %s (scale factor %s)", output, scale_factor)
+ newrow.append([x * scale_factor, w * scale_factor, output])
+ p.update(1)
+ newtiles.append([y * scale_factor, h * scale_factor, newrow])
+
+ newgrid = images.Grid(
+ newtiles,
+ tile_w=grid.tile_w * scale_factor,
+ tile_h=grid.tile_h * scale_factor,
+ image_w=grid.image_w * scale_factor,
+ image_h=grid.image_h * scale_factor,
+ overlap=grid.overlap * scale_factor,
+ )
+ return images.combine_grid(newgrid)