From 4c24347e45776d505937856ab280548d9298f0a8 Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 24 Oct 2022 23:04:50 -0400 Subject: Remove BSRGAN from --use-cpu, add SwinIR --- modules/swinir_model.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'modules/swinir_model.py') diff --git a/modules/swinir_model.py b/modules/swinir_model.py index baa02e3d..facd262d 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -7,8 +7,8 @@ from PIL import Image from basicsr.utils.download_util import load_file_from_url from tqdm import tqdm -from modules import modelloader -from modules.shared import cmd_opts, opts, device +from modules import modelloader, devices +from modules.shared import cmd_opts, opts from modules.swinir_model_arch import SwinIR as net from modules.swinir_model_arch_v2 import Swin2SR as net2 from modules.upscaler import Upscaler, UpscalerData @@ -42,7 +42,7 @@ class UpscalerSwinIR(Upscaler): model = self.load_model(model_file) if model is None: return img - model = model.to(device) + model = model.to(devices.device_swinir) img = upscale(img, model) try: torch.cuda.empty_cache() @@ -111,7 +111,7 @@ def upscale( img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device) + img = img.unsqueeze(0).to(devices.device_swinir) with torch.no_grad(), precision_scope("cuda"): _, _, h_old, w_old = img.size() h_pad = (h_old // window_size + 1) * window_size - h_old @@ -139,8 +139,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale): stride = tile - tile_overlap h_idx_list = list(range(0, h - tile, stride)) + [h - tile] w_idx_list = list(range(0, w - tile, stride)) + [w - tile] - E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img) - W = torch.zeros_like(E, dtype=torch.half, device=device) + E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=devices.device_swinir).type_as(img) + W = torch.zeros_like(E, dtype=torch.half, device=devices.device_swinir) with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: for h_idx in h_idx_list: -- cgit v1.2.3 From faed465a0b1a7d19669568738c93e04907c10415 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 25 Oct 2022 02:01:57 -0400 Subject: MPS Upscalers Fix Get ESRGAN, SCUNet, and SwinIR working correctly on MPS by ensuring memory is contiguous for tensor views before sending to MPS device. --- modules/swinir_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/swinir_model.py') diff --git a/modules/swinir_model.py b/modules/swinir_model.py index facd262d..4253b66d 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -111,7 +111,7 @@ def upscale( img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(devices.device_swinir) + img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir) with torch.no_grad(), precision_scope("cuda"): _, _, h_old, w_old = img.size() h_pad = (h_old // window_size + 1) * window_size - h_old -- cgit v1.2.3