From f299645aeeb65fcddde2d136fd550b6b01ffebb3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 4 Sep 2022 18:54:12 +0300 Subject: ESRGAN support --- modules/esrgan_model.py | 134 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 modules/esrgan_model.py (limited to 'modules/esrgan_model.py') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py new file mode 100644 index 00000000..3dcef5a6 --- /dev/null +++ b/modules/esrgan_model.py @@ -0,0 +1,134 @@ +import os +import sys +import traceback + +import numpy as np +import torch +from PIL import Image + +import modules.esrgam_model_arch as arch +from modules import shared +from modules.shared import opts +import modules.images + + +def load_model(filename): + # this code is adapted from https://github.com/xinntao/ESRGAN + + pretrained_net = torch.load(filename) + crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) + + if 'conv_first.weight' in pretrained_net: + crt_model.load_state_dict(pretrained_net) + return crt_model + + crt_net = crt_model.state_dict() + load_net_clean = {} + for k, v in pretrained_net.items(): + if k.startswith('module.'): + load_net_clean[k[7:]] = v + else: + load_net_clean[k] = v + pretrained_net = load_net_clean + + tbd = [] + for k, v in crt_net.items(): + tbd.append(k) + + # directly copy + for k, v in crt_net.items(): + if k in pretrained_net and pretrained_net[k].size() == v.size(): + crt_net[k] = pretrained_net[k] + tbd.remove(k) + + crt_net['conv_first.weight'] = pretrained_net['model.0.weight'] + crt_net['conv_first.bias'] = pretrained_net['model.0.bias'] + + for k in tbd.copy(): + if 'RDB' in k: + ori_k = k.replace('RRDB_trunk.', 'model.1.sub.') + if '.weight' in k: + ori_k = ori_k.replace('.weight', '.0.weight') + elif '.bias' in k: + ori_k = ori_k.replace('.bias', '.0.bias') + crt_net[k] = pretrained_net[ori_k] + tbd.remove(k) + + crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight'] + crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias'] + crt_net['upconv1.weight'] = pretrained_net['model.3.weight'] + crt_net['upconv1.bias'] = pretrained_net['model.3.bias'] + crt_net['upconv2.weight'] = pretrained_net['model.6.weight'] + crt_net['upconv2.bias'] = pretrained_net['model.6.bias'] + crt_net['HRconv.weight'] = pretrained_net['model.8.weight'] + crt_net['HRconv.bias'] = pretrained_net['model.8.bias'] + crt_net['conv_last.weight'] = pretrained_net['model.10.weight'] + crt_net['conv_last.bias'] = pretrained_net['model.10.bias'] + + crt_model.load_state_dict(crt_net) + crt_model.eval() + return crt_model + +def upscale_without_tiling(model, img): + img = np.array(img) + img = img[:, :, ::-1] + img = np.moveaxis(img, 2, 0) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(shared.device) + with torch.no_grad(): + output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() + output = 255. * np.moveaxis(output, 0, 2) + output = output.astype(np.uint8) + output = output[:, :, ::-1] + return Image.fromarray(output, 'RGB') + + +def esrgan_upscale(model, img): + if opts.ESRGAN_tile == 0: + return upscale_without_tiling(model, img) + + grid = modules.images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap) + newtiles = [] + scale_factor = 1 + + for y, h, row in grid.tiles: + newrow = [] + for tiledata in row: + x, w, tile = tiledata + + output = upscale_without_tiling(model, tile) + scale_factor = output.width // tile.width + + newrow.append([x * scale_factor, w * scale_factor, output]) + newtiles.append([y * scale_factor, h * scale_factor, newrow]) + + newgrid = modules.images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor) + output = modules.images.combine_grid(newgrid) + return output + + +class UpscalerESRGAN(modules.images.Upscaler): + def __init__(self, filename, title): + self.name = title + self.model = load_model(filename) + + def do_upscale(self, img): + model = self.model.to(shared.device) + img = esrgan_upscale(model, img) + return img + + +def load_models(dirname): + for file in os.listdir(dirname): + path = os.path.join(dirname, file) + model_name, extension = os.path.splitext(file) + + if extension != '.pt' and extension != '.pth': + continue + + try: + modules.shared.sd_upscalers.append(UpscalerESRGAN(path, model_name)) + except Exception: + print(f"Error loading ESRGAN model: {path}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) -- cgit v1.2.3