diff options
author | Aarni Koskela <akx@iki.fi> | 2024-01-03 20:39:12 +0000 |
---|---|---|
committer | Aarni Koskela <akx@iki.fi> | 2024-01-03 20:39:12 +0000 |
commit | 62470ee23443cb2ad3943a152ccae26a689c86e1 (patch) | |
tree | 855e6a5048895ce3e18af0a4d5db9ff3efc99c7e | |
parent | 3d31d5c27beb433fa37b30f135ec06a278a87630 (diff) | |
download | stable-diffusion-webui-gfx803-62470ee23443cb2ad3943a152ccae26a689c86e1.tar.gz stable-diffusion-webui-gfx803-62470ee23443cb2ad3943a152ccae26a689c86e1.tar.bz2 stable-diffusion-webui-gfx803-62470ee23443cb2ad3943a152ccae26a689c86e1.zip |
upscale_2: cast image to model's dtype
-rw-r--r-- | modules/upscaler_utils.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/modules/upscaler_utils.py b/modules/upscaler_utils.py index e4c63f09..5db74877 100644 --- a/modules/upscaler_utils.py +++ b/modules/upscaler_utils.py @@ -94,6 +94,7 @@ def tiled_upscale_2( tile_size: int, tile_overlap: int, scale: int, + device: torch.device, desc="Tiled upscale", ): # Alternative implementation of `upscale_with_model` originally used by @@ -101,9 +102,6 @@ def tiled_upscale_2( # weighting is done in PyTorch space, as opposed to `images.Grid` doing it in # Pillow space without weighting. - # Grab the device the model is on, and use it. - device = torch_utils.get_param(model).device - b, c, h, w = img.size() tile_size = min(tile_size, h, w) @@ -175,7 +173,8 @@ def upscale_2( """ Convenience wrapper around `tiled_upscale_2` that handles PIL images. """ - tensor = pil_image_to_torch_bgr(img).float().unsqueeze(0) # add batch dimension + param = torch_utils.get_param(model) + tensor = pil_image_to_torch_bgr(img).to(dtype=param.dtype).unsqueeze(0) # add batch dimension with torch.no_grad(): output = tiled_upscale_2( @@ -185,5 +184,6 @@ def upscale_2( tile_overlap=tile_overlap, scale=scale, desc=desc, + device=param.device, ) return torch_bgr_to_pil_image(output) |