From e472383acbb9e07dca311abe5fb16ee2675e410a Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 27 Dec 2023 11:04:33 +0200 Subject: Refactor esrgan_upscale to more generic upscale_with_model --- modules/esrgan_model.py | 47 ++++++++--------------------------------------- 1 file changed, 8 insertions(+), 39 deletions(-) (limited to 'modules/esrgan_model.py') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 02a1727d..c0d22a99 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,13 +1,12 @@ import sys -import numpy as np import torch -from PIL import Image import modules.esrgan_model_arch as arch -from modules import modelloader, images, devices +from modules import modelloader, devices from modules.shared import opts from modules.upscaler import Upscaler, UpscalerData +from modules.upscaler_utils import upscale_with_model def mod2normal(state_dict): @@ -190,40 +189,10 @@ class UpscalerESRGAN(Upscaler): return model -def upscale_without_tiling(model, img): - img = np.array(img) - img = img[:, :, ::-1] - img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 - img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(devices.device_esrgan) - with torch.no_grad(): - output = model(img) - output = output.squeeze().float().cpu().clamp_(0, 1).numpy() - output = 255. * np.moveaxis(output, 0, 2) - output = output.astype(np.uint8) - output = output[:, :, ::-1] - return Image.fromarray(output, 'RGB') - - def esrgan_upscale(model, img): - if opts.ESRGAN_tile == 0: - return upscale_without_tiling(model, img) - - grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap) - newtiles = [] - scale_factor = 1 - - for y, h, row in grid.tiles: - newrow = [] - for tiledata in row: - x, w, tile = tiledata - - output = upscale_without_tiling(model, tile) - scale_factor = output.width // tile.width - - newrow.append([x * scale_factor, w * scale_factor, output]) - newtiles.append([y * scale_factor, h * scale_factor, newrow]) - - newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor) - output = images.combine_grid(newgrid) - return output + return upscale_with_model( + model, + img, + tile_size=opts.ESRGAN_tile, + tile_overlap=opts.ESRGAN_tile_overlap, + ) -- cgit v1.2.3 From b0f59342346b1c8b405f97c0e0bb01c6ae05c601 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 25 Dec 2023 14:43:51 +0200 Subject: Use Spandrel for upscaling and face restoration architectures (aside from GFPGAN and LDSR) --- modules/esrgan_model.py | 153 +++--------------------------------------------- 1 file changed, 8 insertions(+), 145 deletions(-) (limited to 'modules/esrgan_model.py') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index c0d22a99..a7c7c9e3 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,122 +1,9 @@ -import sys - -import torch - -import modules.esrgan_model_arch as arch -from modules import modelloader, devices +from modules import modelloader, devices, errors from modules.shared import opts from modules.upscaler import Upscaler, UpscalerData from modules.upscaler_utils import upscale_with_model -def mod2normal(state_dict): - # this code is copied from https://github.com/victorca25/iNNfer - if 'conv_first.weight' in state_dict: - crt_net = {} - items = list(state_dict) - - crt_net['model.0.weight'] = state_dict['conv_first.weight'] - crt_net['model.0.bias'] = state_dict['conv_first.bias'] - - for k in items.copy(): - if 'RDB' in k: - ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') - if '.weight' in k: - ori_k = ori_k.replace('.weight', '.0.weight') - elif '.bias' in k: - ori_k = ori_k.replace('.bias', '.0.bias') - crt_net[ori_k] = state_dict[k] - items.remove(k) - - crt_net['model.1.sub.23.weight'] = state_dict['trunk_conv.weight'] - crt_net['model.1.sub.23.bias'] = state_dict['trunk_conv.bias'] - crt_net['model.3.weight'] = state_dict['upconv1.weight'] - crt_net['model.3.bias'] = state_dict['upconv1.bias'] - crt_net['model.6.weight'] = state_dict['upconv2.weight'] - crt_net['model.6.bias'] = state_dict['upconv2.bias'] - crt_net['model.8.weight'] = state_dict['HRconv.weight'] - crt_net['model.8.bias'] = state_dict['HRconv.bias'] - crt_net['model.10.weight'] = state_dict['conv_last.weight'] - crt_net['model.10.bias'] = state_dict['conv_last.bias'] - state_dict = crt_net - return state_dict - - -def resrgan2normal(state_dict, nb=23): - # this code is copied from https://github.com/victorca25/iNNfer - if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict: - re8x = 0 - crt_net = {} - items = list(state_dict) - - crt_net['model.0.weight'] = state_dict['conv_first.weight'] - crt_net['model.0.bias'] = state_dict['conv_first.bias'] - - for k in items.copy(): - if "rdb" in k: - ori_k = k.replace('body.', 'model.1.sub.') - ori_k = ori_k.replace('.rdb', '.RDB') - if '.weight' in k: - ori_k = ori_k.replace('.weight', '.0.weight') - elif '.bias' in k: - ori_k = ori_k.replace('.bias', '.0.bias') - crt_net[ori_k] = state_dict[k] - items.remove(k) - - crt_net[f'model.1.sub.{nb}.weight'] = state_dict['conv_body.weight'] - crt_net[f'model.1.sub.{nb}.bias'] = state_dict['conv_body.bias'] - crt_net['model.3.weight'] = state_dict['conv_up1.weight'] - crt_net['model.3.bias'] = state_dict['conv_up1.bias'] - crt_net['model.6.weight'] = state_dict['conv_up2.weight'] - crt_net['model.6.bias'] = state_dict['conv_up2.bias'] - - if 'conv_up3.weight' in state_dict: - # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py - re8x = 3 - crt_net['model.9.weight'] = state_dict['conv_up3.weight'] - crt_net['model.9.bias'] = state_dict['conv_up3.bias'] - - crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight'] - crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias'] - crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight'] - crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias'] - - state_dict = crt_net - return state_dict - - -def infer_params(state_dict): - # this code is copied from https://github.com/victorca25/iNNfer - scale2x = 0 - scalemin = 6 - n_uplayer = 0 - plus = False - - for block in list(state_dict): - parts = block.split(".") - n_parts = len(parts) - if n_parts == 5 and parts[2] == "sub": - nb = int(parts[3]) - elif n_parts == 3: - part_num = int(parts[1]) - if (part_num > scalemin - and parts[0] == "model" - and parts[2] == "weight"): - scale2x += 1 - if part_num > n_uplayer: - n_uplayer = part_num - out_nc = state_dict[block].shape[0] - if not plus and "conv1x1" in block: - plus = True - - nf = state_dict["model.0.weight"].shape[0] - in_nc = state_dict["model.0.weight"].shape[1] - out_nc = out_nc - scale = 2 ** scale2x - - return in_nc, out_nc, nf, nb, plus, scale - - class UpscalerESRGAN(Upscaler): def __init__(self, dirname): self.name = "ESRGAN" @@ -142,12 +29,11 @@ class UpscalerESRGAN(Upscaler): def do_upscale(self, img, selected_model): try: model = self.load_model(selected_model) - except Exception as e: - print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr) + except Exception: + errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True) return img model.to(devices.device_esrgan) - img = esrgan_upscale(model, img) - return img + return esrgan_upscale(model, img) def load_model(self, path: str): if path.startswith("http"): @@ -160,33 +46,10 @@ class UpscalerESRGAN(Upscaler): else: filename = path - state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) - - if "params_ema" in state_dict: - state_dict = state_dict["params_ema"] - elif "params" in state_dict: - state_dict = state_dict["params"] - num_conv = 16 if "realesr-animevideov3" in filename else 32 - model = arch.SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=num_conv, upscale=4, act_type='prelu') - model.load_state_dict(state_dict) - model.eval() - return model - - if "body.0.rdb1.conv1.weight" in state_dict and "conv_first.weight" in state_dict: - nb = 6 if "RealESRGAN_x4plus_anime_6B" in filename else 23 - state_dict = resrgan2normal(state_dict, nb) - elif "conv_first.weight" in state_dict: - state_dict = mod2normal(state_dict) - elif "model.0.weight" not in state_dict: - raise Exception("The file is not a recognized ESRGAN model.") - - in_nc, out_nc, nf, nb, plus, mscale = infer_params(state_dict) - - model = arch.RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb, upscale=mscale, plus=plus) - model.load_state_dict(state_dict) - model.eval() - - return model + return modelloader.load_spandrel_model( + filename, + device=('cpu' if devices.device_esrgan.type == 'mps' else None), + ) def esrgan_upscale(model, img): -- cgit v1.2.3 From 4ad0c0c0a805da4bac03cff86ea17c25a1291546 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 16:37:03 +0200 Subject: Verify architecture for loaded Spandrel models --- modules/esrgan_model.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/esrgan_model.py') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index a7c7c9e3..70041ab0 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -49,6 +49,7 @@ class UpscalerESRGAN(Upscaler): return modelloader.load_spandrel_model( filename, device=('cpu' if devices.device_esrgan.type == 'mps' else None), + expected_architecture='ESRGAN', ) -- cgit v1.2.3