From e472383acbb9e07dca311abe5fb16ee2675e410a Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 27 Dec 2023 11:04:33 +0200 Subject: Refactor esrgan_upscale to more generic upscale_with_model --- modules/upscaler_utils.py | 66 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 modules/upscaler_utils.py (limited to 'modules/upscaler_utils.py') diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py new file mode 100644 index 00000000..8bdda51c --- /dev/null +++ b/modules/upscaler_utils.py @@ -0,0 +1,66 @@ +import logging +from typing import Callable + +import numpy as np +import torch +import tqdm +from PIL import Image + +from modules import devices, images + +logger = logging.getLogger(__name__) + + +def upscale_without_tiling(model, img: Image.Image): + img = np.array(img) + img = img[:, :, ::-1] + img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(devices.device_esrgan) + with torch.no_grad(): + output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() + output = 255. * np.moveaxis(output, 0, 2) + output = output.astype(np.uint8) + output = output[:, :, ::-1] + return Image.fromarray(output, 'RGB') + + +def upscale_with_model( + model: Callable[[torch.Tensor], torch.Tensor], + img: Image.Image, + *, + tile_size: int, + tile_overlap: int = 0, + desc="tiled upscale", +) -> Image.Image: + if tile_size <= 0: + logger.debug("Upscaling %s without tiling", img) + output = upscale_without_tiling(model, img) + logger.debug("=> %s", output) + return output + + grid = images.split_grid(img, tile_size, tile_size, tile_overlap) + newtiles = [] + + with tqdm.tqdm(total=grid.tile_count, desc=desc) as p: + for y, h, row in grid.tiles: + newrow = [] + for x, w, tile in row: + logger.debug("Tile (%d, %d) %s...", x, y, tile) + output = upscale_without_tiling(model, tile) + scale_factor = output.width // tile.width + logger.debug("=> %s (scale factor %s)", output, scale_factor) + newrow.append([x * scale_factor, w * scale_factor, output]) + p.update(1) + newtiles.append([y * scale_factor, h * scale_factor, newrow]) + + newgrid = images.Grid( + newtiles, + tile_w=grid.tile_w * scale_factor, + tile_h=grid.tile_h * scale_factor, + image_w=grid.image_w * scale_factor, + image_h=grid.image_h * scale_factor, + overlap=grid.overlap * scale_factor, + ) + return images.combine_grid(newgrid) -- cgit v1.2.3 From 8100e901ab0c5b04d289eebb722c8a653b8beef1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 30 Dec 2023 22:41:53 +0300 Subject: fix error with RealESRGAN model failing to upscale fp32 image --- modules/upscaler_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules/upscaler_utils.py') diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 8bdda51c..39f78a0b 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -16,9 +16,13 @@ def upscale_without_tiling(model, img: Image.Image): img = img[:, :, ::-1] img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(devices.device_esrgan) + + model_weight = next(iter(model.parameters())) + img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype) + with torch.no_grad(): output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() output = 255. * np.moveaxis(output, 0, 2) output = output.astype(np.uint8) -- cgit v1.2.3 From 3be90740316f8fbb950b31d440458a5e8ed4beb3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 31 Dec 2023 00:43:41 +0300 Subject: fix for the previous fix. --- modules/upscaler_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/upscaler_utils.py') diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 39f78a0b..dde5d7ad 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -17,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - model_weight = next(iter(model.parameters())) + model_weight = next(iter(model.model.parameters())) img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype) with torch.no_grad(): -- cgit v1.2.3 From 777af661a21821994993df3ef566b01df2bb61a0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 00:09:51 +0200 Subject: Be more clear about Spandrel model nomenclature --- modules/upscaler_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/upscaler_utils.py') diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index dde5d7ad..174c9bc3 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import torch import tqdm from PIL import Image -from modules import devices, images +from modules import images logger = logging.getLogger(__name__) -- cgit v1.2.3 From 6f86b62a1be7993073ba3a789d522e0b8870605a Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 22:53:49 +0200 Subject: Deduplicate tiled inference code from SwinIR/ScuNET --- modules/upscaler_utils.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 1 deletion(-) (limited to 'modules/upscaler_utils.py') diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 174c9bc3..8e413854 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,7 +6,7 @@ import torch import tqdm from PIL import Image -from modules import images +from modules import images, shared logger = logging.getLogger(__name__) @@ -68,3 +68,73 @@ def upscale_with_model( overlap=grid.overlap * scale_factor, ) return images.combine_grid(newgrid) + + +def tiled_upscale_2( + img, + model, + *, + tile_size: int, + tile_overlap: int, + scale: int, + device, + desc="Tiled upscale", +): + # Alternative implementation of `upscale_with_model` originally used by + # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and + # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in + # Pillow space without weighting. + b, c, h, w = img.size() + tile_size = min(tile_size, h, w) + + if tile_size <= 0: + logger.debug("Upscaling %s without tiling", img.shape) + return model(img) + + stride = tile_size - tile_overlap + h_idx_list = list(range(0, h - tile_size, stride)) + [h - tile_size] + w_idx_list = list(range(0, w - tile_size, stride)) + [w - tile_size] + result = torch.zeros( + b, + c, + h * scale, + w * scale, + device=device, + ).type_as(img) + weights = torch.zeros_like(result) + logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) + with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar: + for h_idx in h_idx_list: + if shared.state.interrupted or shared.state.skipped: + break + + for w_idx in w_idx_list: + if shared.state.interrupted or shared.state.skipped: + break + + in_patch = img[ + ..., + h_idx : h_idx + tile_size, + w_idx : w_idx + tile_size, + ] + out_patch = model(in_patch) + + result[ + ..., + h_idx * scale : (h_idx + tile_size) * scale, + w_idx * scale : (w_idx + tile_size) * scale, + ].add_(out_patch) + + out_patch_mask = torch.ones_like(out_patch) + + weights[ + ..., + h_idx * scale : (h_idx + tile_size) * scale, + w_idx * scale : (w_idx + tile_size) * scale, + ].add_(out_patch_mask) + + pbar.update(1) + + output = result.div_(weights) + + return output -- cgit v1.2.3 From 5768afc776a66bb94e77a9c1daebeea58fa731d5 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 00:20:30 +0200 Subject: Add utility to inspect a model's parameters (to get dtype/device) --- modules/upscaler_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules/upscaler_utils.py') diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 8e413854..c60e3beb 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -7,6 +7,7 @@ import tqdm from PIL import Image from modules import images, shared +from modules.torch_utils import get_param logger = logging.getLogger(__name__) @@ -17,8 +18,8 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - model_weight = next(iter(model.model.parameters())) - img = img.unsqueeze(0).to(device=model_weight.device, dtype=model_weight.dtype) + param = get_param(model) + img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): output = model(img) -- cgit v1.2.3 From a70dfb64a86b9b6d869deffdb0ffebe980365473 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 31 Dec 2023 22:38:30 +0300 Subject: change import statements for #14478 --- modules/upscaler_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'modules/upscaler_utils.py') diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index c60e3beb..f5cb92d5 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -6,8 +6,7 @@ import torch import tqdm from PIL import Image -from modules import images, shared -from modules.torch_utils import get_param +from modules import images, shared, torch_utils logger = logging.getLogger(__name__) @@ -18,7 +17,7 @@ def upscale_without_tiling(model, img: Image.Image): img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - param = get_param(model) + param = torch_utils.get_param(model) img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): -- cgit v1.2.3 From 1341b2208185cd89b0019bda2df63b406ec0cb5e Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 2 Jan 2024 06:47:26 +0300 Subject: add an option to hide upscaling progressbar --- modules/upscaler_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/upscaler_utils.py') diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index f5cb92d5..9379f512 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -47,7 +47,7 @@ def upscale_with_model( grid = images.split_grid(img, tile_size, tile_size, tile_overlap) newtiles = [] - with tqdm.tqdm(total=grid.tile_count, desc=desc) as p: + with tqdm.tqdm(total=grid.tile_count, desc=desc, disable=not shared.opts.enable_upscale_progressbar) as p: for y, h, row in grid.tiles: newrow = [] for x, w, tile in row: @@ -103,7 +103,7 @@ def tiled_upscale_2( ).type_as(img) weights = torch.zeros_like(result) logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) - with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc) as pbar: + with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar: for h_idx in h_idx_list: if shared.state.interrupted or shared.state.skipped: break -- cgit v1.2.3 From cf14a6a7aaf8ccb40552990785d5c9e400d93610 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 16:11:18 +0200 Subject: Refactor upscale_2 helper out of ScuNET/SwinIR; make sure devices are right --- modules/upscaler_utils.py | 89 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 69 insertions(+), 20 deletions(-) (limited to 'modules/upscaler_utils.py') diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index 9379f512..e4c63f09 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -11,23 +11,40 @@ from modules import images, shared, torch_utils logger = logging.getLogger(__name__) -def upscale_without_tiling(model, img: Image.Image): - img = np.array(img) - img = img[:, :, ::-1] - img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 - img = torch.from_numpy(img).float() - +def pil_image_to_torch_bgr(img: Image.Image) -> torch.Tensor: + img = np.array(img.convert("RGB")) + img = img[:, :, ::-1] # flip RGB to BGR + img = np.transpose(img, (2, 0, 1)) # HWC to CHW + img = np.ascontiguousarray(img) / 255 # Rescale to [0, 1] + return torch.from_numpy(img) + + +def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image: + if tensor.ndim == 4: + # If we're given a tensor with a batch dimension, squeeze it out + # (but only if it's a batch of size 1). + if tensor.shape[0] != 1: + raise ValueError(f"{tensor.shape} does not describe a BCHW tensor") + tensor = tensor.squeeze(0) + assert tensor.ndim == 3, f"{tensor.shape} does not describe a CHW tensor" + # TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom? + arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp + arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale + arr = arr.astype(np.uint8) + arr = arr[:, :, ::-1] # flip BGR to RGB + return Image.fromarray(arr, "RGB") + + +def upscale_pil_patch(model, img: Image.Image) -> Image.Image: + """ + Upscale a given PIL image using the given model. + """ param = torch_utils.get_param(model) - img = img.unsqueeze(0).to(device=param.device, dtype=param.dtype) with torch.no_grad(): - output = model(img) - - output = output.squeeze().float().cpu().clamp_(0, 1).numpy() - output = 255. * np.moveaxis(output, 0, 2) - output = output.astype(np.uint8) - output = output[:, :, ::-1] - return Image.fromarray(output, 'RGB') + tensor = pil_image_to_torch_bgr(img).unsqueeze(0) # add batch dimension + tensor = tensor.to(device=param.device, dtype=param.dtype) + return torch_bgr_to_pil_image(model(tensor)) def upscale_with_model( @@ -40,7 +57,7 @@ def upscale_with_model( ) -> Image.Image: if tile_size <= 0: logger.debug("Upscaling %s without tiling", img) - output = upscale_without_tiling(model, img) + output = upscale_pil_patch(model, img) logger.debug("=> %s", output) return output @@ -52,7 +69,7 @@ def upscale_with_model( newrow = [] for x, w, tile in row: logger.debug("Tile (%d, %d) %s...", x, y, tile) - output = upscale_without_tiling(model, tile) + output = upscale_pil_patch(model, tile) scale_factor = output.width // tile.width logger.debug("=> %s (scale factor %s)", output, scale_factor) newrow.append([x * scale_factor, w * scale_factor, output]) @@ -71,19 +88,22 @@ def upscale_with_model( def tiled_upscale_2( - img, + img: torch.Tensor, model, *, tile_size: int, tile_overlap: int, scale: int, - device, desc="Tiled upscale", ): # Alternative implementation of `upscale_with_model` originally used by # SwinIR and ScuNET. It differs from `upscale_with_model` in that tiling and # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in # Pillow space without weighting. + + # Grab the device the model is on, and use it. + device = torch_utils.get_param(model).device + b, c, h, w = img.size() tile_size = min(tile_size, h, w) @@ -100,7 +120,8 @@ def tiled_upscale_2( h * scale, w * scale, device=device, - ).type_as(img) + dtype=img.dtype, + ) weights = torch.zeros_like(result) logger.debug("Upscaling %s to %s with tiles", img.shape, result.shape) with tqdm.tqdm(total=len(h_idx_list) * len(w_idx_list), desc=desc, disable=not shared.opts.enable_upscale_progressbar) as pbar: @@ -112,11 +133,13 @@ def tiled_upscale_2( if shared.state.interrupted or shared.state.skipped: break + # Only move this patch to the device if it's not already there. in_patch = img[ ..., h_idx : h_idx + tile_size, w_idx : w_idx + tile_size, - ] + ].to(device=device) + out_patch = model(in_patch) result[ @@ -138,3 +161,29 @@ def tiled_upscale_2( output = result.div_(weights) return output + + +def upscale_2( + img: Image.Image, + model, + *, + tile_size: int, + tile_overlap: int, + scale: int, + desc: str, +): + """ + Convenience wrapper around `tiled_upscale_2` that handles PIL images. + """ + tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension + + with torch.no_grad(): + output = tiled_upscale_2( + tensor, + model, + tile_size=tile_size, + tile_overlap=tile_overlap, + scale=scale, + desc=desc, + ) + return torch_bgr_to_pil_image(output) -- cgit v1.2.3 From 7ad6899bf987a8ee615efbcfc99562457f89cd8b Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Tue, 2 Jan 2024 17:14:05 +0200 Subject: torch_bgr_to_pil_image: round, don't truncate This matches what `realesrgan` does. --- modules/upscaler_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/upscaler_utils.py') diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index e4c63f09..4f1417cf 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -30,7 +30,7 @@ def torch_bgr_to_pil_image(tensor: torch.Tensor) -> Image.Image: # TODO: is `tensor.float().cpu()...numpy()` the most efficient idiom? arr = tensor.float().cpu().clamp_(0, 1).numpy() # clamp arr = 255.0 * np.moveaxis(arr, 0, 2) # CHW to HWC, rescale - arr = arr.astype(np.uint8) + arr = arr.round().astype(np.uint8) arr = arr[:, :, ::-1] # flip BGR to RGB return Image.fromarray(arr, "RGB") -- cgit v1.2.3 From 62470ee23443cb2ad3943a152ccae26a689c86e1 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 3 Jan 2024 22:39:12 +0200 Subject: upscale_2: cast image to model's dtype --- modules/upscaler_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'modules/upscaler_utils.py') diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index e4c63f09..5db74877 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -94,6 +94,7 @@ def tiled_upscale_2( tile_size: int, tile_overlap: int, scale: int, + device: torch.device, desc="Tiled upscale", ): # Alternative implementation of `upscale_with_model` originally used by @@ -101,9 +102,6 @@ def tiled_upscale_2( # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in # Pillow space without weighting. - # Grab the device the model is on, and use it. - device = torch_utils.get_param(model).device - b, c, h, w = img.size() tile_size = min(tile_size, h, w) @@ -175,7 +173,8 @@ def upscale_2( """ Convenience wrapper around `tiled_upscale_2` that handles PIL images. """ - tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension + param = torch_utils.get_param(model) + tensor = pil_image_to_torch_bgr(img).to(dtype=param.dtype).unsqueeze(0) # add batch dimension with torch.no_grad(): output = tiled_upscale_2( @@ -185,5 +184,6 @@ def upscale_2( tile_overlap=tile_overlap, scale=scale, desc=desc, + device=param.device, ) return torch_bgr_to_pil_image(output) -- cgit v1.2.3