From c9bb33dd43dbb9479ff1b70351df14508c89ac60 Mon Sep 17 00:00:00 2001 From: victorca25 Date: Sun, 30 Oct 2022 12:52:50 +0100 Subject: add resrgan 8x, allow use 1x and up to 8x extra models, move BSRGAN model, add nearest --- modules/esrgan_model.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) (limited to 'modules/esrgan_model.py') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index a13cf6ac..c61669b4 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -50,6 +50,7 @@ def mod2normal(state_dict): def resrgan2normal(state_dict, nb=23): # this code is copied from https://github.com/victorca25/iNNfer if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict: + re8x = 0 crt_net = {} items = [] for k, v in state_dict.items(): @@ -75,10 +76,18 @@ def resrgan2normal(state_dict, nb=23): crt_net['model.3.bias'] = state_dict['conv_up1.bias'] crt_net['model.6.weight'] = state_dict['conv_up2.weight'] crt_net['model.6.bias'] = state_dict['conv_up2.bias'] - crt_net['model.8.weight'] = state_dict['conv_hr.weight'] - crt_net['model.8.bias'] = state_dict['conv_hr.bias'] - crt_net['model.10.weight'] = state_dict['conv_last.weight'] - crt_net['model.10.bias'] = state_dict['conv_last.bias'] + + if 'conv_up3.weight' in state_dict: + # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py + re8x = 3 + crt_net['model.9.weight'] = state_dict['conv_up3.weight'] + crt_net['model.9.bias'] = state_dict['conv_up3.bias'] + + crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight'] + crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias'] + crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight'] + crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias'] + state_dict = crt_net return state_dict -- cgit v1.2.3 From abfa22c16fb3d9b1ed8d049c7b68e94d1cca5b82 Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 7 Nov 2022 19:25:43 -0500 Subject: Revert "MPS Upscalers Fix" This reverts commit 768b95394a8500da639b947508f78296524f1836. --- modules/devices.py | 9 --------- modules/esrgan_model.py | 2 +- modules/scunet_model.py | 3 ++- modules/swinir_model.py | 2 +- 4 files changed, 4 insertions(+), 12 deletions(-) (limited to 'modules/esrgan_model.py') diff --git a/modules/devices.py b/modules/devices.py index 67165bf6..a87d0d4c 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -94,12 +94,3 @@ def autocast(disable=False): return contextlib.nullcontext() return torch.autocast("cuda") - - -# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 -def mps_contiguous(input_tensor, device): - return input_tensor.contiguous() if device.type == 'mps' else input_tensor - - -def mps_contiguous_to(input_tensor, device): - return mps_contiguous(input_tensor, device).to(device) diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index c61669b4..9a9c38f1 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -199,7 +199,7 @@ def upscale_without_tiling(model, img): img = img[:, :, ::-1] img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_esrgan) + img = img.unsqueeze(0).to(devices.device_esrgan) with torch.no_grad(): output = model(img) output = output.squeeze().float().cpu().clamp_(0, 1).numpy() diff --git a/modules/scunet_model.py b/modules/scunet_model.py index 59532274..36a996bf 100644 --- a/modules/scunet_model.py +++ b/modules/scunet_model.py @@ -54,8 +54,9 @@ class UpscalerScuNET(modules.upscaler.Upscaler): img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = devices.mps_contiguous_to(img.unsqueeze(0), device) + img = img.unsqueeze(0).to(device) + img = img.to(device) with torch.no_grad(): output = model(img) output = output.squeeze().float().cpu().clamp_(0, 1).numpy() diff --git a/modules/swinir_model.py b/modules/swinir_model.py index 4253b66d..facd262d 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -111,7 +111,7 @@ def upscale( img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir) + img = img.unsqueeze(0).to(devices.device_swinir) with torch.no_grad(), precision_scope("cuda"): _, _, h_old, w_old = img.size() h_pad = (h_old // window_size + 1) * window_size - h_old -- cgit v1.2.3