diff options
author | Aarni Koskela <akx@iki.fi> | 2023-12-27 09:04:33 +0000 |
---|---|---|
committer | Aarni Koskela <akx@iki.fi> | 2023-12-30 14:24:01 +0000 |
commit | e472383acbb9e07dca311abe5fb16ee2675e410a (patch) | |
tree | 69591965d87134116235daa785d31f60b70791b4 /modules | |
parent | 12c6f37f8e4b1d1d643c9d8d5dfc763c3203c728 (diff) | |
download | stable-diffusion-webui-gfx803-e472383acbb9e07dca311abe5fb16ee2675e410a.tar.gz stable-diffusion-webui-gfx803-e472383acbb9e07dca311abe5fb16ee2675e410a.tar.bz2 stable-diffusion-webui-gfx803-e472383acbb9e07dca311abe5fb16ee2675e410a.zip |
Refactor esrgan_upscale to more generic upscale_with_model
Diffstat (limited to 'modules')
-rw-r--r-- | modules/esrgan_model.py | 47 | ||||
-rw-r--r-- | modules/upscaler_utils.py | 66 |
2 files changed, 74 insertions, 39 deletions
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 02a1727d..c0d22a99 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -1,13 +1,12 @@
 import sys
 
-import numpy as np
 import torch
-from PIL import Image
 
 import modules.esrgan_model_arch as arch
-from modules import modelloader, images, devices
+from modules import modelloader, devices
 from modules.shared import opts
 from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model
 
 
 def mod2normal(state_dict):
@@ -190,40 +189,10 @@ class UpscalerESRGAN(Upscaler):
         return model
 
 
-def upscale_without_tiling(model, img):
-    img = np.array(img)
-    img = img[:, :, ::-1]
-    img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
-    img = torch.from_numpy(img).float()
-    img = img.unsqueeze(0).to(devices.device_esrgan)
-    with torch.no_grad():
-        output = model(img)
-    output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
-    output = 255. * np.moveaxis(output, 0, 2)
-    output = output.astype(np.uint8)
-    output = output[:, :, ::-1]
-    return Image.fromarray(output, 'RGB')
-
-
 def esrgan_upscale(model, img):
-    if opts.ESRGAN_tile == 0:
-        return upscale_without_tiling(model, img)
-
-    grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
-    newtiles = []
-    scale_factor = 1
-
-    for y, h, row in grid.tiles:
-        newrow = []
-        for tiledata in row:
-            x, w, tile = tiledata
-
-            output = upscale_without_tiling(model, tile)
-            scale_factor = output.width // tile.width
-
-            newrow.append([x * scale_factor, w * scale_factor, output])
-        newtiles.append([y * scale_factor, h * scale_factor, newrow])
-
-    newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
-    output = images.combine_grid(newgrid)
-    return output
+    return upscale_with_model(
+        model,
+        img,
+        tile_size=opts.ESRGAN_tile,
+        tile_overlap=opts.ESRGAN_tile_overlap,
+    )
diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py
new file mode 100644
index 00000000..8bdda51c
--- /dev/null
+++ b/modules/upscaler_utils.py
@@ -0,0 +1,66 @@
+import logging
+from typing import Callable
+
+import numpy as np
+import torch
+import tqdm
+from PIL import Image
+
+from modules import devices, images
+
+logger = logging.getLogger(__name__)
+
+
+def upscale_without_tiling(model, img: Image.Image):
+    img = np.array(img)
+    img = img[:, :, ::-1]
+    img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
+    img = torch.from_numpy(img).float()
+    img = img.unsqueeze(0).to(devices.device_esrgan)
+    with torch.no_grad():
+        output = model(img)
+    output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
+    output = 255. * np.moveaxis(output, 0, 2)
+    output = output.astype(np.uint8)
+    output = output[:, :, ::-1]
+    return Image.fromarray(output, 'RGB')
+
+
+def upscale_with_model(
+    model: Callable[[torch.Tensor], torch.Tensor],
+    img: Image.Image,
+    *,
+    tile_size: int,
+    tile_overlap: int = 0,
+    desc="tiled upscale",
+) -> Image.Image:
+    if tile_size <= 0:
+        logger.debug("Upscaling %s without tiling", img)
+        output = upscale_without_tiling(model, img)
+        logger.debug("=> %s", output)
+        return output
+
+    grid = images.split_grid(img, tile_size, tile_size, tile_overlap)
+    newtiles = []
+
+    with tqdm.tqdm(total=grid.tile_count, desc=desc) as p:
+        for y, h, row in grid.tiles:
+            newrow = []
+            for x, w, tile in row:
+                logger.debug("Tile (%d, %d) %s...", x, y, tile)
+                output = upscale_without_tiling(model, tile)
+                scale_factor = output.width // tile.width
+                logger.debug("=> %s (scale factor %s)", output, scale_factor)
+                newrow.append([x * scale_factor, w * scale_factor, output])
+                p.update(1)
+            newtiles.append([y * scale_factor, h * scale_factor, newrow])
+
+    newgrid = images.Grid(
+        newtiles,
+        tile_w=grid.tile_w * scale_factor,
+        tile_h=grid.tile_h * scale_factor,
+        image_w=grid.image_w * scale_factor,
+        image_h=grid.image_h * scale_factor,
+        overlap=grid.overlap * scale_factor,
+    )
+    return images.combine_grid(newgrid)