Diffstat (limited to 'modules')
-rw-r--r--  modules/bsrgan_model.py        79
-rw-r--r--  modules/bsrgan_model_arch.py   103
-rw-r--r--  modules/esrgan_model.py        227
-rw-r--r--  modules/extras.py              35
-rw-r--r--  modules/gfpgan_model.py        58
-rw-r--r--  modules/gfpgan_model_arch.py   150
-rw-r--r--  modules/images.py              84
-rw-r--r--  modules/ldsr_model.py          103
-rw-r--r--  modules/ldsr_model_arch.py     223
-rw-r--r--  modules/modelloader.py         74
-rw-r--r--  modules/realesrgan_model.py    209
-rw-r--r--  modules/sd_models.py           3
-rw-r--r--  modules/sd_samplers.py         4
-rw-r--r--  modules/shared.py              18
-rw-r--r--  modules/swinir_model.py        157
-rw-r--r--  modules/upscaler.py            121
16 files changed, 1009 insertions, 639 deletions
diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py
new file mode 100644
index 00000000..77141545
--- /dev/null
+++ b/modules/bsrgan_model.py
@@ -0,0 +1,79 @@
+import os.path
+import sys
+import traceback
+
+import PIL.Image
+import numpy as np
+import torch
+from basicsr.utils.download_util import load_file_from_url
+
+import modules.upscaler
+from modules import shared, modelloader
+from modules.bsrgan_model_arch import RRDBNet
+from modules.paths import models_path
+
+
+class UpscalerBSRGAN(modules.upscaler.Upscaler):
+    def __init__(self, dirname):
+        self.name = "BSRGAN"
+        self.model_path = os.path.join(models_path, self.name)
+        self.model_name = "BSRGAN 4x"
+        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
+        self.user_path = dirname
+        super().__init__()
+        model_paths = self.find_models(ext_filter=[".pt", ".pth"])
+        scalers = []
+        if len(model_paths) == 0:
+            scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
+            scalers.append(scaler_data)
+        for file in model_paths:
+            if "http" in file:
+                name = self.model_name
+            else:
+                name = modelloader.friendly_name(file)
+            try:
+                scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
+                scalers.append(scaler_data)
+            except Exception:
+                print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+        self.scalers = scalers
+
+    def do_upscale(self, img: PIL.Image, selected_file):
+        torch.cuda.empty_cache()
+        model = self.load_model(selected_file)
+        if model is None:
+            return img
+        model.to(shared.device)
+        torch.cuda.empty_cache()
+        img = np.array(img)
+        img = img[:, :, ::-1]  # RGB -> BGR
+        img = np.moveaxis(img, 2, 0) / 255  # HWC -> CHW, scaled to [0, 1]
+        img = torch.from_numpy(img).float()
+        img = img.unsqueeze(0).to(shared.device)
+        with torch.no_grad():
+            output = model(img)
+        output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
+        output = 255. * np.moveaxis(output, 0, 2)  # CHW -> HWC, back to [0, 255]
+        output = output.astype(np.uint8)
+        output = output[:, :, ::-1]  # BGR -> RGB
+        torch.cuda.empty_cache()
+        return PIL.Image.fromarray(output, 'RGB')
+
+    def load_model(self, path: str):
+        if "http" in path:
+            filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
+                                          file_name="%s.pth" % self.name, progress=True)
+        else:
+            filename = path
+        if filename is None or not os.path.exists(filename):
+            print("Unable to load %s from %s" % (self.model_path, filename))
+            return None
+        print("Loading %s from %s" % (self.model_path, filename))
+        model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4)  # define network; BSRGAN.pth is the 4x model
+        model.load_state_dict(torch.load(filename), strict=True)
+        model.eval()
+        for k, v in model.named_parameters():
+            v.requires_grad = False
+        return model
+
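A usage sketch for the new class (illustrative, not part of the commit; assumes a models/BSRGAN folder and the data_path attribute that UpscalerData gets from modules/upscaler.py):

    # Hypothetical standalone driver for UpscalerBSRGAN; in the webui this
    # wiring is handled by the shared upscaler setup instead.
    import PIL.Image
    from modules.bsrgan_model import UpscalerBSRGAN

    upscaler = UpscalerBSRGAN("models/BSRGAN")            # dirname = user model folder
    img = PIL.Image.open("input.png").convert("RGB")
    scaler = upscaler.scalers[0]                          # local .pth path or download URL
    result = upscaler.do_upscale(img, scaler.data_path)   # returns a 4x-upscaled PIL image
    result.save("output.png")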
diff --git a/modules/bsrgan_model_arch.py b/modules/bsrgan_model_arch.py
new file mode 100644
index 00000000..d72647db
--- /dev/null
+++ b/modules/bsrgan_model_arch.py
@@ -0,0 +1,103 @@
+import functools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+
+
+def initialize_weights(net_l, scale=1):
+    if not isinstance(net_l, list):
+        net_l = [net_l]
+    for net in net_l:
+        for m in net.modules():
+            if isinstance(m, nn.Conv2d):
+                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
+                m.weight.data *= scale  # for residual block
+                if m.bias is not None:
+                    m.bias.data.zero_()
+            elif isinstance(m, nn.Linear):
+                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
+                m.weight.data *= scale
+                if m.bias is not None:
+                    m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                init.constant_(m.weight, 1)
+                init.constant_(m.bias.data, 0.0)
+
+
+def make_layer(block, n_layers):
+    layers = []
+    for _ in range(n_layers):
+        layers.append(block())
+    return nn.Sequential(*layers)
+
+
+class ResidualDenseBlock_5C(nn.Module):
+    def __init__(self, nf=64, gc=32, bias=True):
+        super(ResidualDenseBlock_5C, self).__init__()
+        # gc: growth channel, i.e. intermediate channels
+        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
+        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
+        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
+        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
+        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+        # initialization
+        initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+    def forward(self, x):
+        x1 = self.lrelu(self.conv1(x))
+        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+    '''Residual in Residual Dense Block'''
+
+    def __init__(self, nf, gc=32):
+        super(RRDB, self).__init__()
+        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
+        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
+        self.RDB3 = ResidualDenseBlock_5C(nf, gc)
+
+    def forward(self, x):
+        out = self.RDB1(x)
+        out = self.RDB2(out)
+        out = self.RDB3(out)
+        return out * 0.2 + x
+
+
+class RRDBNet(nn.Module):
+    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
+        super(RRDBNet, self).__init__()
+        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
+        self.sf = sf
+        print([in_nc, out_nc, nf, nb, gc, sf])
+
+        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
+        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
+        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        #### upsampling
+        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        if self.sf == 4:
+            self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
+        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
+
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        fea = self.conv_first(x)
+        trunk = self.trunk_conv(self.RRDB_trunk(fea))
+        fea = fea + trunk
+
+        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
+        if self.sf == 4:
+            fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
+        out = self.conv_last(self.lrelu(self.HRconv(fea)))
+
+        return out
\ No newline at end of file
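A quick shape check for the new architecture (a sketch, not in the commit): with sf=4 the forward pass upsamples twice with nearest-neighbour interpolation, so spatial dimensions grow by a factor of four end to end:

    import torch
    from modules.bsrgan_model_arch import RRDBNet

    net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4).eval()
    x = torch.randn(1, 3, 64, 64)      # dummy NCHW input
    with torch.no_grad():
        y = net(x)
    print(y.shape)                     # torch.Size([1, 3, 256, 256])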
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 5e10c49c..ce841aa4 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -1,6 +1,4 @@
import os
-import sys
-import traceback
import numpy as np
import torch
@@ -8,93 +6,119 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url
import modules.esrgam_model_arch as arch
-import modules.images
-from modules import shared
-from modules import shared, modelloader
+from modules import shared, modelloader, images
from modules.devices import has_mps
from modules.paths import models_path
+from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
-model_dir = "ESRGAN"
-model_path = os.path.join(models_path, model_dir)
-model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
-model_name = "ESRGAN_x4"
-
-
-def load_model(path: str, name: str):
-    global model_path
-    global model_url
-    global model_dir
-    global model_name
-    if "http" in path:
-        filename = load_file_from_url(url=model_url, model_dir=model_path, file_name="%s.pth" % model_name, progress=True)
-    else:
-        filename = path
-    if not os.path.exists(filename) or filename is None:
-        print("Unable to load %s from %s" % (model_dir, filename))
-        return None
-    print("Loading %s from %s" % (model_dir, filename))
-    # this code is adapted from https://github.com/xinntao/ESRGAN
-    pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
-    crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
-
-    if 'conv_first.weight' in pretrained_net:
-        crt_model.load_state_dict(pretrained_net)
-        return crt_model
-    if 'model.0.weight' not in pretrained_net:
-        is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
-        if is_realesrgan:
-            raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
-        else:
-            raise Exception("The file is not a ESRGAN model.")
+class UpscalerESRGAN(Upscaler):
+    def __init__(self, dirname):
+        self.name = "ESRGAN"
+        self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
+        self.model_name = "ESRGAN 4x"
+        self.scalers = []
+        self.user_path = dirname
+        self.model_path = os.path.join(models_path, self.name)
+        super().__init__()
+        model_paths = self.find_models(ext_filter=[".pt", ".pth"])
+        if len(model_paths) == 0:
+            scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
+            self.scalers.append(scaler_data)
+        for file in model_paths:
+            print(f"File: {file}")
+            if "http" in file:
+                name = self.model_name
+            else:
+                name = modelloader.friendly_name(file)
+
+            scaler_data = UpscalerData(name, file, self, 4)
+            print(f"ESRGAN: Adding scaler {name}")
+            self.scalers.append(scaler_data)
+
+    def do_upscale(self, img, selected_model):
+        model = self.load_model(selected_model)
+        if model is None:
+            return img
+        model.to(shared.device)
+        img = esrgan_upscale(model, img)
+        return img
-    crt_net = crt_model.state_dict()
-    load_net_clean = {}
-    for k, v in pretrained_net.items():
-        if k.startswith('module.'):
-            load_net_clean[k[7:]] = v
+    def load_model(self, path: str):
+        if "http" in path:
+            filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
+                                          file_name="%s.pth" % self.model_name,
+                                          progress=True)
         else:
-            load_net_clean[k] = v
-    pretrained_net = load_net_clean
-
-    tbd = []
-    for k, v in crt_net.items():
-        tbd.append(k)
-
-    # directly copy
-    for k, v in crt_net.items():
-        if k in pretrained_net and pretrained_net[k].size() == v.size():
-            crt_net[k] = pretrained_net[k]
-            tbd.remove(k)
-
-    crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
-    crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
-
-    for k in tbd.copy():
-        if 'RDB' in k:
-            ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
-            if '.weight' in k:
-                ori_k = ori_k.replace('.weight', '.0.weight')
-            elif '.bias' in k:
-                ori_k = ori_k.replace('.bias', '.0.bias')
-            crt_net[k] = pretrained_net[ori_k]
-            tbd.remove(k)
-
-    crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
-    crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
-    crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
-    crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
-    crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
-    crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
-    crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
-    crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
-    crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
-    crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
-
-    crt_model.load_state_dict(crt_net)
-    crt_model.eval()
-    return crt_model
+            filename = path
+        if filename is None or not os.path.exists(filename):
+            print("Unable to load %s from %s" % (self.model_path, filename))
+            return None
+        # this code is adapted from https://github.com/xinntao/ESRGAN
+        pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
+        crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
+
+        if 'conv_first.weight' in pretrained_net:
+            crt_model.load_state_dict(pretrained_net)
+            return crt_model
+
+        if 'model.0.weight' not in pretrained_net:
+            is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net[
+                "params_ema"]
+            if is_realesrgan:
+                raise Exception("The file is a RealESRGAN model, it can't be used as an ESRGAN model.")
+            else:
+                raise Exception("The file is not an ESRGAN model.")
+
+        crt_net = crt_model.state_dict()
+        load_net_clean = {}
+        for k, v in pretrained_net.items():
+            if k.startswith('module.'):
+                load_net_clean[k[7:]] = v
+            else:
+                load_net_clean[k] = v
+        pretrained_net = load_net_clean
+
+        tbd = []
+        for k, v in crt_net.items():
+            tbd.append(k)
+
+        # directly copy
+        for k, v in crt_net.items():
+            if k in pretrained_net and pretrained_net[k].size() == v.size():
+                crt_net[k] = pretrained_net[k]
+                tbd.remove(k)
+
+        crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
+        crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
+
+        for k in tbd.copy():
+            if 'RDB' in k:
+                ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
+                if '.weight' in k:
+                    ori_k = ori_k.replace('.weight', '.0.weight')
+                elif '.bias' in k:
+                    ori_k = ori_k.replace('.bias', '.0.bias')
+                crt_net[k] = pretrained_net[ori_k]
+                tbd.remove(k)
+
+        crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
+        crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
+        crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
+        crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
+        crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
+        crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
+        crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
+        crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
+        crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
+        crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
+
+        crt_model.load_state_dict(crt_net)
+        crt_model.eval()
+        return crt_model
+
 def upscale_without_tiling(model, img):
     img = np.array(img)
@@ -115,7 +139,7 @@ def esrgan_upscale(model, img):
     if opts.ESRGAN_tile == 0:
         return upscale_without_tiling(model, img)

-    grid = modules.images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
+    grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
     newtiles = []
     scale_factor = 1

@@ -130,38 +154,7 @@ def esrgan_upscale(model, img):
             newrow.append([x * scale_factor, w * scale_factor, output])
         newtiles.append([y * scale_factor, h * scale_factor, newrow])

-    newgrid = modules.images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
-    output = modules.images.combine_grid(newgrid)
+    newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor,
+                          grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
+    output = images.combine_grid(newgrid)
     return output
-
-
-class UpscalerESRGAN(modules.images.Upscaler):
-    def __init__(self, filename, title):
-        self.name = title
-        self.filename = filename
-
-    def do_upscale(self, img):
-        model = load_model(self.filename, self.name)
-        if model is None:
-            return img
-        model.to(shared.device)
-        img = esrgan_upscale(model, img)
-        return img
-
-
-def setup_model(dirname):
-    global model_path
-    global model_name
-    if not os.path.exists(model_path):
-        os.makedirs(model_path)
-
-    model_paths = modelloader.load_models(model_path, command_path=dirname, ext_filter=[".pt", ".pth"])
-    if len(model_paths) == 0:
-        modules.shared.sd_upscalers.append(UpscalerESRGAN(model_url, model_name))
-    for file in model_paths:
-        name = modelloader.friendly_name(file)
-        try:
-            modules.shared.sd_upscalers.append(UpscalerESRGAN(file, name))
-        except Exception:
-            print(f"Error loading ESRGAN model: {file}", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
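The checkpoint conversion in load_model maps old-style ESRGAN keys onto the new architecture's names; the trunk-block rename can be illustrated in isolation (values are made up for the example):

    # Old ESRGAN checkpoints store trunk blocks as 'model.1.sub.N' and wrap
    # each conv in an extra '.0' module; load_model reverses that mapping.
    new_key = 'RRDB_trunk.5.RDB1.conv1.weight'
    old_key = new_key.replace('RRDB_trunk.', 'model.1.sub.').replace('.weight', '.0.weight')
    print(old_key)                     # model.1.sub.5.RDB1.conv1.0.weight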
diff --git a/modules/extras.py b/modules/extras.py
index af6e631f..d7d0fa54 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -66,29 +66,28 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
image = res
- if upscaling_resize != 1.0:
- def upscale(image, scaler_index, resize):
- small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
- pixels = tuple(np.array(small).flatten().tolist())
- key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
+ def upscale(image, scaler_index, resize):
+ small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
+ pixels = tuple(np.array(small).flatten().tolist())
+ key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
- c = cached_images.get(key)
- if c is None:
- upscaler = shared.sd_upscalers[scaler_index]
- c = upscaler.upscale(image, image.width * resize, image.height * resize)
- cached_images[key] = c
+ c = cached_images.get(key)
+ if c is None:
+ upscaler = shared.sd_upscalers[scaler_index]
+ c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
+ cached_images[key] = c
- return c
+ return c
- info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
- res = upscale(image, extras_upscaler_1, upscaling_resize)
+ info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
+ res = upscale(image, extras_upscaler_1, upscaling_resize)
- if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
- res2 = upscale(image, extras_upscaler_2, upscaling_resize)
- info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
- res = Image.blend(res, res2, extras_upscaler_2_visibility)
+ if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
+ res2 = upscale(image, extras_upscaler_2, upscaling_resize)
+ info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
+ res = Image.blend(res, res2, extras_upscaler_2_visibility)
- image = res
+ image = res
while len(cached_images) > 2:
del cached_images[next(iter(cached_images.keys()))]
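The trimming loop above keeps cached_images at two entries by deleting the first key in iteration order, which on Python 3.7+ is the oldest insertion; the same FIFO-eviction idiom in isolation:

    # Self-contained sketch of the cached_images eviction behaviour.
    cache = {}
    for i in range(5):
        cache[f"key{i}"] = f"value{i}"
        while len(cache) > 2:                      # keep at most two entries
            del cache[next(iter(cache.keys()))]    # drop the oldest insertion
    print(list(cache))                             # ['key3', 'key4']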
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py
index ffb6960d..2bf8a1ee 100644
--- a/modules/gfpgan_model.py
+++ b/modules/gfpgan_model.py
@@ -1,24 +1,23 @@
import os
import sys
import traceback
-from glob import glob
-from modules import shared, devices
-from modules.shared import cmd_opts
-from modules.paths import script_path
+import facexlib
+import gfpgan
+
import modules.face_restoration
from modules import shared, devices, modelloader
from modules.paths import models_path
model_dir = "GFPGAN"
-cmd_dir = None
+user_path = None
model_path = os.path.join(models_path, model_dir)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
-
+have_gfpgan = False
loaded_gfpgan_model = None
-def gfpgan():
+def gfpgann():
     global loaded_gfpgan_model
     global model_path
     if loaded_gfpgan_model is not None:
@@ -28,14 +27,16 @@ def gfpgan():
     if gfpgan_constructor is None:
         return None

-    models = modelloader.load_models(model_path, model_url, cmd_dir)
-    if len(models) != 0:
+    models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
+    if len(models) == 1 and "http" in models[0]:
+        model_file = models[0]
+    elif len(models) != 0:
         latest_file = max(models, key=os.path.getctime)
         model_file = latest_file
     else:
         print("Unable to load gfpgan model!")
         return None
-    model = gfpgan_constructor(model_path=model_file, model_dir=model_path, upscale=1, arch='clean', channel_multiplier=2,
+    model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2,
                                bg_upsampler=None)
     model.gfpgan.to(shared.device)
     loaded_gfpgan_model = model
@@ -44,11 +45,12 @@ def gfpgan():
def gfpgan_fix_faces(np_image):
-    model = gfpgan()
+    model = gfpgann()
     if model is None:
         return np_image

     np_image_bgr = np_image[:, :, ::-1]
-    cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
+    cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False,
+                                                                     only_center_face=False, paste_back=True)
     np_image = gfpgan_output_bgr[:, :, ::-1]

     if shared.opts.face_restoration_unload:
@@ -57,7 +59,6 @@ def gfpgan_fix_faces(np_image):
     return np_image
-have_gfpgan = False
gfpgan_constructor = None
@@ -67,14 +68,33 @@ def setup_model(dirname):
         os.makedirs(model_path)

     try:
-        from modules.gfpgan_model_arch import GFPGANerr
-        global cmd_dir
+        from gfpgan import GFPGANer
+        from facexlib import detection, parsing
+        global user_path
         global have_gfpgan
         global gfpgan_constructor
-        cmd_dir = dirname
+        load_file_from_url_orig = gfpgan.utils.load_file_from_url
+        facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
+        facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
+
+        def my_load_file_from_url(**kwargs):
+            print("Setting model_dir to " + model_path)
+            return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
+
+        def facex_load_file_from_url(**kwargs):
+            return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
+
+        def facex_load_file_from_url2(**kwargs):
+            return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
+
+        gfpgan.utils.load_file_from_url = my_load_file_from_url
+        facexlib.detection.load_file_from_url = facex_load_file_from_url
+        facexlib.parsing.load_file_from_url = facex_load_file_from_url2
+        user_path = dirname
         have_gfpgan = True
-        gfpgan_constructor = GFPGANerr
+        gfpgan_constructor = GFPGANer
         class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
             def name(self):
@@ -82,7 +102,9 @@ def setup_model(dirname):
             def restore(self, np_image):
                 np_image_bgr = np_image[:, :, ::-1]
-                cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
+                cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False,
+                                                                                     only_center_face=False,
+                                                                                     paste_back=True)
                 np_image = gfpgan_output_bgr[:, :, ::-1]
                 return np_image
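The download redirection above works by monkey-patching the load_file_from_url helpers inside gfpgan and facexlib so every checkpoint lands in the models/GFPGAN directory. The pattern, reduced to its essentials (a sketch; the target directory is illustrative):

    # Capture the original function, then re-register a wrapper that forces
    # the destination directory regardless of what the caller asked for.
    import basicsr.utils.download_util as du

    _orig_load = du.load_file_from_url

    def load_into_fixed_dir(**kwargs):
        return _orig_load(**dict(kwargs, model_dir="models/GFPGAN"))

    du.load_file_from_url = load_into_fixed_dir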
diff --git a/modules/gfpgan_model_arch.py b/modules/gfpgan_model_arch.py
deleted file mode 100644
index d81cea96..00000000
--- a/modules/gfpgan_model_arch.py
+++ /dev/null
@@ -1,150 +0,0 @@
-# GFPGAN likes to download stuff "wherever", and we're trying to fix that, so this is a copy of the original...
-
-import cv2
-import os
-import torch
-from basicsr.utils import img2tensor, tensor2img
-from basicsr.utils.download_util import load_file_from_url
-from facexlib.utils.face_restoration_helper import FaceRestoreHelper
-from torchvision.transforms.functional import normalize
-
-from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
-from gfpgan.archs.gfpganv1_arch import GFPGANv1
-from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
-
-ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
-
-
-class GFPGANerr():
- """Helper for restoration with GFPGAN.
-
- It will detect and crop faces, and then resize the faces to 512x512.
- GFPGAN is used to restored the resized faces.
- The background is upsampled with the bg_upsampler.
- Finally, the faces will be pasted back to the upsample background image.
-
- Args:
- model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically).
- upscale (float): The upscale of the final output. Default: 2.
- arch (str): The GFPGAN architecture. Option: clean | original. Default: clean.
- channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
- bg_upsampler (nn.Module): The upsampler for the background. Default: None.
- """
-
-    def __init__(self, model_path, model_dir, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
-        self.upscale = upscale
-        self.bg_upsampler = bg_upsampler
-
-        # initialize model
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
-        # initialize the GFP-GAN
-        if arch == 'clean':
-            self.gfpgan = GFPGANv1Clean(
-                out_size=512,
-                num_style_feat=512,
-                channel_multiplier=channel_multiplier,
-                decoder_load_path=None,
-                fix_decoder=False,
-                num_mlp=8,
-                input_is_latent=True,
-                different_w=True,
-                narrow=1,
-                sft_half=True)
-        elif arch == 'bilinear':
-            self.gfpgan = GFPGANBilinear(
-                out_size=512,
-                num_style_feat=512,
-                channel_multiplier=channel_multiplier,
-                decoder_load_path=None,
-                fix_decoder=False,
-                num_mlp=8,
-                input_is_latent=True,
-                different_w=True,
-                narrow=1,
-                sft_half=True)
-        elif arch == 'original':
-            self.gfpgan = GFPGANv1(
-                out_size=512,
-                num_style_feat=512,
-                channel_multiplier=channel_multiplier,
-                decoder_load_path=None,
-                fix_decoder=True,
-                num_mlp=8,
-                input_is_latent=True,
-                different_w=True,
-                narrow=1,
-                sft_half=True)
-        elif arch == 'RestoreFormer':
-            from gfpgan.archs.restoreformer_arch import RestoreFormer
-            self.gfpgan = RestoreFormer()
-        # initialize face helper
-        self.face_helper = FaceRestoreHelper(
-            upscale,
-            face_size=512,
-            crop_ratio=(1, 1),
-            det_model='retinaface_resnet50',
-            save_ext='png',
-            use_parse=True,
-            device=self.device,
-            model_rootpath=model_dir)
-
-        if model_path.startswith('https://'):
-            model_path = load_file_from_url(
-                url=model_path, model_dir=model_dir, progress=True, file_name=None)
-        loadnet = torch.load(model_path)
-        if 'params_ema' in loadnet:
-            keyname = 'params_ema'
-        else:
-            keyname = 'params'
-        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
-        self.gfpgan.eval()
-        self.gfpgan = self.gfpgan.to(self.device)
-
-    @torch.no_grad()
-    def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5):
-        self.face_helper.clean_all()
-
-        if has_aligned:  # the inputs are already aligned
-            img = cv2.resize(img, (512, 512))
-            self.face_helper.cropped_faces = [img]
-        else:
-            self.face_helper.read_image(img)
-            # get face landmarks for each face
-            self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
-            # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
-            # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
-            # align and warp each face
-            self.face_helper.align_warp_face()
-
-        # face restoration
-        for cropped_face in self.face_helper.cropped_faces:
-            # prepare data
-            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
-            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
-            cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
-
-            try:
-                output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
-                # convert to image
-                restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
-            except RuntimeError as error:
-                print(f'\tFailed inference for GFPGAN: {error}.')
-                restored_face = cropped_face
-
-            restored_face = restored_face.astype('uint8')
-            self.face_helper.add_restored_face(restored_face)
-
-        if not has_aligned and paste_back:
-            # upsample the background
-            if self.bg_upsampler is not None:
-                # Now only support RealESRGAN for upsampling background
-                bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
-            else:
-                bg_img = None
-
-            self.face_helper.get_inverse_affine(None)
-            # paste each restored face to the input image
-            restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
-            return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
-        else:
-            return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
diff --git a/modules/images.py b/modules/images.py
index 9458bf8d..a6538dbe 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -11,7 +11,6 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from fonts.ttf import Roboto
import string
-import modules.shared
from modules import sd_samplers, shared
from modules.shared import opts, cmd_opts
@@ -52,8 +51,8 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
     cols = math.ceil((w - overlap) / non_overlap_width)
     rows = math.ceil((h - overlap) / non_overlap_height)

-    dx = (w - tile_w) / (cols-1) if cols > 1 else 0
-    dy = (h - tile_h) / (rows-1) if rows > 1 else 0
+    dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
+    dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
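Worked numbers for the tile placement (illustrative, assuming non_overlap_width = tile_w - overlap as defined earlier in split_grid): a 1000-pixel-wide image with 512-pixel tiles and 64-pixel overlap needs cols = ceil((1000 - 64) / 448) = 3, and dx = (1000 - 512) / 2 = 244, so tiles start at x = 0, 244 and 488:

    import math

    w, tile_w, overlap = 1000, 512, 64
    non_overlap_width = tile_w - overlap                   # 448
    cols = math.ceil((w - overlap) / non_overlap_width)    # 3
    dx = (w - tile_w) / (cols - 1) if cols > 1 else 0      # 244.0
    print([int(col * dx) for col in range(cols)])          # [0, 244, 488]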