From dfdc51246c678b585e1bdfdb7d2f202b0ca0e362 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 3 Jan 2024 22:38:13 +0200 Subject: SwinIR: use prefer_half --- extensions-builtin/SwinIR/scripts/swinir_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index bc427fea..6a8e21b0 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,6 +1,7 @@ import logging import sys +import torch from PIL import Image from modules import devices, modelloader, script_callbacks, shared, upscaler_utils @@ -69,7 +70,7 @@ class UpscalerSwinIR(Upscaler): model_descriptor = modelloader.load_spandrel_model( filename, device=self._get_device(), - dtype=devices.dtype, + prefer_half=(devices.dtype == torch.float16), expected_architecture="SwinIR", ) if getattr(shared.opts, 'SWIN_torch_compile', False): -- cgit v1.2.3 From 3d31d5c27beb433fa37b30f135ec06a278a87630 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 3 Jan 2024 22:38:49 +0200 Subject: SwinIR: pass model.scale --- extensions-builtin/SwinIR/scripts/swinir_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 6a8e21b0..16bf9b79 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -51,7 +51,7 @@ class UpscalerSwinIR(Upscaler): model, tile_size=shared.opts.SWIN_tile, tile_overlap=shared.opts.SWIN_tile_overlap, - scale=4, # TODO: This was hard-coded before too... + scale=model.scale, desc="SwinIR", ) devices.torch_gc() -- cgit v1.2.3 From 62470ee23443cb2ad3943a152ccae26a689c86e1 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 3 Jan 2024 22:39:12 +0200 Subject: upscale_2: cast image to model's dtype --- modules/upscaler_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index e4c63f09..5db74877 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -94,6 +94,7 @@ def tiled_upscale_2( tile_size: int, tile_overlap: int, scale: int, + device: torch.device, desc="Tiled upscale", ): # Alternative implementation of `upscale_with_model` originally used by @@ -101,9 +102,6 @@ def tiled_upscale_2( # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in # Pillow space without weighting. - # Grab the device the model is on, and use it. - device = torch_utils.get_param(model).device - b, c, h, w = img.size() tile_size = min(tile_size, h, w) @@ -175,7 +173,8 @@ def upscale_2( """ Convenience wrapper around `tiled_upscale_2` that handles PIL images. """ - tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension + param = torch_utils.get_param(model) + tensor = pil_image_to_torch_bgr(img).to(dtype=param.dtype).unsqueeze(0) # add batch dimension with torch.no_grad(): output = tiled_upscale_2( @@ -185,5 +184,6 @@ def upscale_2( tile_overlap=tile_overlap, scale=scale, desc=desc, + device=param.device, ) return torch_bgr_to_pil_image(output) -- cgit v1.2.3