From ad4de819c43997f2666b5bad95301f5c37f9018e Mon Sep 17 00:00:00 2001 From: victorca25 Date: Sun, 9 Oct 2022 13:02:12 +0200 Subject: update ESRGAN architecture and model to support all ESRGAN models in the DB, BSRGAN and real-ESRGAN models --- modules/esrgan_model.py | 190 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 128 insertions(+), 62 deletions(-) (limited to 'modules/esrgan_model.py') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 3970e6e4..a49e2258 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -5,68 +5,115 @@ import torch from PIL import Image from basicsr.utils.download_util import load_file_from_url -import modules.esrgam_model_arch as arch +import modules.esrgan_model_arch as arch from modules import shared, modelloader, images, devices from modules.upscaler import Upscaler, UpscalerData from modules.shared import opts -def fix_model_layers(crt_model, pretrained_net): - # this code is adapted from https://github.com/xinntao/ESRGAN - if 'conv_first.weight' in pretrained_net: - return pretrained_net - if 'model.0.weight' not in pretrained_net: - is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"] - if is_realesrgan: - raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.") - else: - raise Exception("The file is not a ESRGAN model.") +def mod2normal(state_dict): + # this code is copied from https://github.com/victorca25/iNNfer + if 'conv_first.weight' in state_dict: + crt_net = {} + items = [] + for k, v in state_dict.items(): + items.append(k) + + crt_net['model.0.weight'] = state_dict['conv_first.weight'] + crt_net['model.0.bias'] = state_dict['conv_first.bias'] + + for k in items.copy(): + if 'RDB' in k: + ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') + if '.weight' in k: + ori_k = ori_k.replace('.weight', '.0.weight') + elif '.bias' in k: + ori_k = ori_k.replace('.bias', '.0.bias') + crt_net[ori_k] = state_dict[k] + items.remove(k) + + crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight'] + crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias'] + crt_net['model.3.weight'] = state_dict['upconv1.weight'] + crt_net['model.3.bias'] = state_dict['upconv1.bias'] + crt_net['model.6.weight'] = state_dict['upconv2.weight'] + crt_net['model.6.bias'] = state_dict['upconv2.bias'] + crt_net['model.8.weight'] = state_dict['HRconv.weight'] + crt_net['model.8.bias'] = state_dict['HRconv.bias'] + crt_net['model.10.weight'] = state_dict['conv_last.weight'] + crt_net['model.10.bias'] = state_dict['conv_last.bias'] + state_dict = crt_net + return state_dict + + +def resrgan2normal(state_dict, nb=23): + # this code is copied from https://github.com/victorca25/iNNfer + if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict: + crt_net = {} + items = [] + for k, v in state_dict.items(): + items.append(k) + + crt_net['model.0.weight'] = state_dict['conv_first.weight'] + crt_net['model.0.bias'] = state_dict['conv_first.bias'] + + for k in items.copy(): + if "rdb" in k: + ori_k = k.replace('body.', 'model.1.sub.') + ori_k = ori_k.replace('.rdb', '.RDB') + if '.weight' in k: + ori_k = ori_k.replace('.weight', '.0.weight') + elif '.bias' in k: + ori_k = ori_k.replace('.bias', '.0.bias') + crt_net[ori_k] = state_dict[k] + items.remove(k) + + crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight'] + crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias'] + crt_net['model.3.weight'] = state_dict['conv_up1.weight'] + crt_net['model.3.bias'] = state_dict['conv_up1.bias'] + crt_net['model.6.weight'] = state_dict['conv_up2.weight'] + crt_net['model.6.bias'] = state_dict['conv_up2.bias'] + crt_net['model.8.weight'] = state_dict['conv_hr.weight'] + crt_net['model.8.bias'] = state_dict['conv_hr.bias'] + crt_net['model.10.weight'] = state_dict['conv_last.weight'] + crt_net['model.10.bias'] = state_dict['conv_last.bias'] + state_dict = crt_net + return state_dict + + +def infer_params(state_dict): + # this code is copied from https://github.com/victorca25/iNNfer + scale2x = 0 + scalemin = 6 + n_uplayer = 0 + plus = False + + for block in list(state_dict): + parts = block.split(".") + n_parts = len(parts) + if n_parts == 5 and parts[2] == "sub": + nb = int(parts[3]) + elif n_parts == 3: + part_num = int(parts[1]) + if (part_num > scalemin + and parts[0] == "model" + and parts[2] == "weight"): + scale2x += 1 + if part_num > n_uplayer: + n_uplayer = part_num + out_nc = state_dict[block].shape[0] + if not plus and "conv1x1" in block: + plus = True + + nf = state_dict["model.0.weight"].shape[0] + in_nc = state_dict["model.0.weight"].shape[1] + out_nc = out_nc + scale = 2 ** scale2x + + return in_nc, out_nc, nf, nb, plus, scale - crt_net = crt_model.state_dict() - load_net_clean = {} - for k, v in pretrained_net.items(): - if k.startswith('module.'): - load_net_clean[k[7:]] = v - else: - load_net_clean[k] = v - pretrained_net = load_net_clean - - tbd = [] - for k, v in crt_net.items(): - tbd.append(k) - - # directly copy - for k, v in crt_net.items(): - if k in pretrained_net and pretrained_net[k].size() == v.size(): - crt_net[k] = pretrained_net[k] - tbd.remove(k) - - crt_net['conv_first.weight'] = pretrained_net['model.0.weight'] - crt_net['conv_first.bias'] = pretrained_net['model.0.bias'] - - for k in tbd.copy(): - if 'RDB' in k: - ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') - if '.weight' in k: - ori_k = ori_k.replace('.weight', '.0.weight') - elif '.bias' in k: - ori_k = ori_k.replace('.bias', '.0.bias') - crt_net[k] = pretrained_net[ori_k] - tbd.remove(k) - - crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight'] - crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias'] - crt_net['upconv1.weight'] = pretrained_net['model.3.weight'] - crt_net['upconv1.bias'] = pretrained_net['model.3.bias'] - crt_net['upconv2.weight'] = pretrained_net['model.6.weight'] - crt_net['upconv2.bias'] = pretrained_net['model.6.bias'] - crt_net['HRconv.weight'] = pretrained_net['model.8.weight'] - crt_net['HRconv.bias'] = pretrained_net['model.8.bias'] - crt_net['conv_last.weight'] = pretrained_net['model.10.weight'] - crt_net['conv_last.bias'] = pretrained_net['model.10.bias'] - - return crt_net class UpscalerESRGAN(Upscaler): def __init__(self, dirname): @@ -109,20 +156,39 @@ class UpscalerESRGAN(Upscaler): print("Unable to load %s from %s" % (self.model_path, filename)) return None - pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) - crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) + state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) + + if "params_ema" in state_dict: + state_dict = state_dict["params_ema"] + elif "params" in state_dict: + state_dict = state_dict["params"] + num_conv = 16 if "realesr-animevideov3" in filename else 32 + model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu') + model.load_state_dict(state_dict) + model.eval() + return model + + if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict: + nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23 + state_dict = resrgan2normal(state_dict, nb) + elif "conv_first.weight" in state_dict: + state_dict = mod2normal(state_dict) + elif "model.0.weight" not in state_dict: + raise Exception("The file is not a recognized ESRGAN model.") + + in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict) - pretrained_net = fix_model_layers(crt_model, pretrained_net) - crt_model.load_state_dict(pretrained_net) - crt_model.eval() + model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus) + model.load_state_dict(state_dict) + model.eval() - return crt_model + return model def upscale_without_tiling(model, img): img = np.array(img) img = img[:, :, ::-1] - img = np.moveaxis(img, 2, 0) / 255 + img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() img = img.unsqueeze(0).to(devices.device_esrgan) with torch.no_grad(): -- cgit v1.2.3 From faed465a0b1a7d19669568738c93e04907c10415 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 25 Oct 2022 02:01:57 -0400 Subject: MPS Upscalers Fix Get ESRGAN, SCUNet, and SwinIR working correctly on MPS by ensuring memory is contiguous for tensor views before sending to MPS device. --- modules/devices.py | 4 ++++ modules/esrgan_model.py | 2 +- modules/scunet_model.py | 3 +-- modules/swinir_model.py | 2 +- 4 files changed, 7 insertions(+), 4 deletions(-) (limited to 'modules/esrgan_model.py') diff --git a/modules/devices.py b/modules/devices.py index 033a42d5..7511e1dc 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -81,3 +81,7 @@ def autocast(disable=False): return contextlib.nullcontext() return torch.autocast("cuda") + +# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 +def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor +def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device) diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index a49e2258..a13cf6ac 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -190,7 +190,7 @@ def upscale_without_tiling(model, img): img = img[:, :, ::-1] img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(devices.device_esrgan) + img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_esrgan) with torch.no_grad(): output = model(img) output = output.squeeze().float().cpu().clamp_(0, 1).numpy() diff --git a/modules/scunet_model.py b/modules/scunet_model.py index 36a996bf..59532274 100644 --- a/modules/scunet_model.py +++ b/modules/scunet_model.py @@ -54,9 +54,8 @@ class UpscalerScuNET(modules.upscaler.Upscaler): img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(device) + img = devices.mps_contiguous_to(img.unsqueeze(0), device) - img = img.to(device) with torch.no_grad(): output = model(img) output = output.squeeze().float().cpu().clamp_(0, 1).numpy() diff --git a/modules/swinir_model.py b/modules/swinir_model.py index facd262d..4253b66d 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -111,7 +111,7 @@ def upscale( img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(devices.device_swinir) + img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir) with torch.no_grad(), precision_scope("cuda"): _, _, h_old, w_old = img.size() h_pad = (h_old // window_size + 1) * window_size - h_old -- cgit v1.2.3 From c9bb33dd43dbb9479ff1b70351df14508c89ac60 Mon Sep 17 00:00:00 2001 From: victorca25 Date: Sun, 30 Oct 2022 12:52:50 +0100 Subject: add resrgan 8x, allow use 1x and up to 8x extra models, move BSRGAN model, add nearest --- modules/esrgan_model.py | 17 +++++++++++++---- modules/modelloader.py | 3 +++ modules/ui.py | 2 +- modules/upscaler.py | 17 ++++++++++++++++- 4 files changed, 33 insertions(+), 6 deletions(-) (limited to 'modules/esrgan_model.py') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index a13cf6ac..c61669b4 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -50,6 +50,7 @@ def mod2normal(state_dict): def resrgan2normal(state_dict, nb=23): # this code is copied from https://github.com/victorca25/iNNfer if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict: + re8x = 0 crt_net = {} items = [] for k, v in state_dict.items(): @@ -75,10 +76,18 @@ def resrgan2normal(state_dict, nb=23): crt_net['model.3.bias'] = state_dict['conv_up1.bias'] crt_net['model.6.weight'] = state_dict['conv_up2.weight'] crt_net['model.6.bias'] = state_dict['conv_up2.bias'] - crt_net['model.8.weight'] = state_dict['conv_hr.weight'] - crt_net['model.8.bias'] = state_dict['conv_hr.bias'] - crt_net['model.10.weight'] = state_dict['conv_last.weight'] - crt_net['model.10.bias'] = state_dict['conv_last.bias'] + + if 'conv_up3.weight' in state_dict: + # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py + re8x = 3 + crt_net['model.9.weight'] = state_dict['conv_up3.weight'] + crt_net['model.9.bias'] = state_dict['conv_up3.bias'] + + crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight'] + crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias'] + crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight'] + crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias'] + state_dict = crt_net return state_dict diff --git a/modules/modelloader.py b/modules/modelloader.py index b0f2f33d..e4a6f8ac 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -85,6 +85,9 @@ def cleanup_models(): src_path = os.path.join(root_path, "ESRGAN") dest_path = os.path.join(models_path, "ESRGAN") move_files(src_path, dest_path) + src_path = os.path.join(models_path, "BSRGAN") + dest_path = os.path.join(models_path, "ESRGAN") + move_files(src_path, dest_path, ".pth") src_path = os.path.join(root_path, "gfpgan") dest_path = os.path.join(models_path, "GFPGAN") move_files(src_path, dest_path) diff --git a/modules/ui.py b/modules/ui.py index 5055ca64..47610f5c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1059,7 +1059,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Tabs(elem_id="extras_resize_mode"): with gr.TabItem('Scale by'): - upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2) + upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4) with gr.TabItem('Scale to'): with gr.Group(): with gr.Row(): diff --git a/modules/upscaler.py b/modules/upscaler.py index 6ab2fb40..83fde7ca 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -10,6 +10,7 @@ import modules.shared from modules import modelloader, shared LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) +NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST) from modules.paths import models_path @@ -57,7 +58,7 @@ class Upscaler: dest_w = img.width * scale dest_h = img.height * scale for i in range(3): - if img.width >= dest_w and img.height >= dest_h: + if img.width > dest_w and img.height > dest_h: break img = self.do_upscale(img, selected_model) if img.width != dest_w or img.height != dest_h: @@ -120,3 +121,17 @@ class UpscalerLanczos(Upscaler): self.name = "Lanczos" self.scalers = [UpscalerData("Lanczos", None, self)] + +class UpscalerNearest(Upscaler): + scalers = [] + + def do_upscale(self, img, selected_model=None): + return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST) + + def load_model(self, _): + pass + + def __init__(self, dirname=None): + super().__init__(False) + self.name = "Nearest" + self.scalers = [UpscalerData("Nearest", None, self)] \ No newline at end of file -- cgit v1.2.3 From abfa22c16fb3d9b1ed8d049c7b68e94d1cca5b82 Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 7 Nov 2022 19:25:43 -0500 Subject: Revert "MPS Upscalers Fix" This reverts commit 768b95394a8500da639b947508f78296524f1836. --- modules/devices.py | 9 --------- modules/esrgan_model.py | 2 +- modules/scunet_model.py | 3 ++- modules/swinir_model.py | 2 +- 4 files changed, 4 insertions(+), 12 deletions(-) (limited to 'modules/esrgan_model.py') diff --git a/modules/devices.py b/modules/devices.py index 67165bf6..a87d0d4c 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -94,12 +94,3 @@ def autocast(disable=False): return contextlib.nullcontext() return torch.autocast("cuda") - - -# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 -def mps_contiguous(input_tensor, device): - return input_tensor.contiguous() if device.type == 'mps' else input_tensor - - -def mps_contiguous_to(input_tensor, device): - return mps_contiguous(input_tensor, device).to(device) diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index c61669b4..9a9c38f1 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -199,7 +199,7 @@ def upscale_without_tiling(model, img): img = img[:, :, ::-1] img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = torch.from_numpy(img).float() - img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_esrgan) + img = img.unsqueeze(0).to(devices.device_esrgan) with torch.no_grad(): output = model(img) output = output.squeeze().float().cpu().clamp_(0, 1).numpy() diff --git a/modules/scunet_model.py b/modules/scunet_model.py index 59532274..36a996bf 100644 --- a/modules/scunet_model.py +++ b/modules/scunet_model.py @@ -54,8 +54,9 @@ class UpscalerScuNET(modules.upscaler.Upscaler): img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = devices.mps_contiguous_to(img.unsqueeze(0), device) + img = img.unsqueeze(0).to(device) + img = img.to(device) with torch.no_grad(): output = model(img) output = output.squeeze().float().cpu().clamp_(0, 1).numpy() diff --git a/modules/swinir_model.py b/modules/swinir_model.py index 4253b66d..facd262d 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -111,7 +111,7 @@ def upscale( img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir) + img = img.unsqueeze(0).to(devices.device_swinir) with torch.no_grad(), precision_scope("cuda"): _, _, h_old, w_old = img.size() h_pad = (h_old // window_size + 1) * window_size - h_old -- cgit v1.2.3