diff options
-rw-r--r-- | .gitignore | 7 | ||||
-rw-r--r-- | ESRGAN/Put ESRGAN models here.txt | 0 | ||||
-rw-r--r-- | SwinIR/put_swinir_models_here.txt | 1 | ||||
-rw-r--r-- | launch.py | 11 | ||||
-rw-r--r-- | modules/bsrgan_model.py | 79 | ||||
-rw-r--r-- | modules/bsrgan_model_arch.py | 103 | ||||
-rw-r--r-- | modules/codeformer_model.py | 44 | ||||
-rw-r--r-- | modules/esrgan_model.py | 205 | ||||
-rw-r--r-- | modules/extras.py | 37 | ||||
-rw-r--r-- | modules/gfpgan_model.py | 100 | ||||
-rw-r--r-- | modules/images.py | 84 | ||||
-rw-r--r-- | modules/ldsr_model.py | 92 | ||||
-rw-r--r-- | modules/ldsr_model_arch.py | 225 | ||||
-rw-r--r-- | modules/modelloader.py | 133 | ||||
-rw-r--r-- | modules/paths.py | 3 | ||||
-rw-r--r-- | modules/realesrgan_model.py | 202 | ||||
-rw-r--r-- | modules/sd_models.py | 61 | ||||
-rw-r--r-- | modules/shared.py | 45 | ||||
-rw-r--r-- | modules/swinir.py | 123 | ||||
-rw-r--r-- | modules/swinir_model.py | 139 | ||||
-rw-r--r-- | modules/swinir_model_arch.py (renamed from modules/swinir_arch.py) | 1734 | ||||
-rw-r--r-- | modules/upscaler.py | 121 | ||||
-rw-r--r-- | webui.py | 46 |
23 files changed, 2179 insertions, 1416 deletions
@@ -1,10 +1,13 @@ __pycache__ -/ESRGAN +*.ckpt +*.pth +/ESRGAN/* +/SwinIR/* /repositories /venv /tmp /model.ckpt -/models/**/*.ckpt +/models/**/* /GFPGANv1.3.pth /gfpgan/weights/*.pth /ui-config.json diff --git a/ESRGAN/Put ESRGAN models here.txt b/ESRGAN/Put ESRGAN models here.txt deleted file mode 100644 index e69de29b..00000000 --- a/ESRGAN/Put ESRGAN models here.txt +++ /dev/null diff --git a/SwinIR/put_swinir_models_here.txt b/SwinIR/put_swinir_models_here.txt deleted file mode 100644 index 8b137891..00000000 --- a/SwinIR/put_swinir_models_here.txt +++ /dev/null @@ -1 +0,0 @@ - @@ -1,5 +1,5 @@ # this scripts installs necessary requirements and launches main program in webui.py
-
+import shutil
import subprocess
import os
import sys
@@ -22,7 +22,6 @@ taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HAS k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "a7ec1974d4ccb394c2dca275f42cd97490618924")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
-ldsr_commit_hash = os.environ.get('LDSR_COMMIT_HASH', "abf33e7002d59d9085081bce93ec798dcabd49af")
args = shlex.split(commandline_args)
@@ -120,9 +119,11 @@ git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
-# Using my repo until my changes are merged, as this makes interfacing with our version of SD-web a lot easier
-git_clone("https://github.com/Hafiidz/latent-diffusion", repo_dir('latent-diffusion'), "LDSR", ldsr_commit_hash)
-
+if os.path.isdir(repo_dir('latent-diffusion')):
+ try:
+ shutil.rmtree(repo_dir('latent-diffusion'))
+ except:
+ pass
if not is_installed("lpips"):
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py new file mode 100644 index 00000000..77141545 --- /dev/null +++ b/modules/bsrgan_model.py @@ -0,0 +1,79 @@ +import os.path +import sys +import traceback + +import PIL.Image +import numpy as np +import torch +from basicsr.utils.download_util import load_file_from_url + +import modules.upscaler +from modules import shared, modelloader +from modules.bsrgan_model_arch import RRDBNet +from modules.paths import models_path + + +class UpscalerBSRGAN(modules.upscaler.Upscaler): + def __init__(self, dirname): + self.name = "BSRGAN" + self.model_path = os.path.join(models_path, self.name) + self.model_name = "BSRGAN 4x" + self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth" + self.user_path = dirname + super().__init__() + model_paths = self.find_models(ext_filter=[".pt", ".pth"]) + scalers = [] + if len(model_paths) == 0: + scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4) + scalers.append(scaler_data) + for file in model_paths: + if "http" in file: + name = self.model_name + else: + name = modelloader.friendly_name(file) + try: + scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) + scalers.append(scaler_data) + except Exception: + print(f"Error loading BSRGAN model: {file}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + self.scalers = scalers + + def do_upscale(self, img: PIL.Image, selected_file): + torch.cuda.empty_cache() + model = self.load_model(selected_file) + if model is None: + return img + model.to(shared.device) + torch.cuda.empty_cache() + img = np.array(img) + img = img[:, :, ::-1] + img = np.moveaxis(img, 2, 0) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(shared.device) + with torch.no_grad(): + output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() + output = 255. * np.moveaxis(output, 0, 2) + output = output.astype(np.uint8) + output = output[:, :, ::-1] + torch.cuda.empty_cache() + return PIL.Image.fromarray(output, 'RGB') + + def load_model(self, path: str): + if "http" in path: + filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, + progress=True) + else: + filename = path + if not os.path.exists(filename) or filename is None: + print("Unable to load %s from %s" % (self.model_dir, filename)) + return None + print("Loading %s from %s" % (self.model_dir, filename)) + model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2) # define network + model.load_state_dict(torch.load(filename), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + return model + diff --git a/modules/bsrgan_model_arch.py b/modules/bsrgan_model_arch.py new file mode 100644 index 00000000..d72647db --- /dev/null +++ b/modules/bsrgan_model_arch.py @@ -0,0 +1,103 @@ +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init + + +def initialize_weights(net_l, scale=1): + if not isinstance(net_l, list): + net_l = [net_l] + for net in net_l: + for m in net.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale # for residual block + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + init.constant_(m.weight, 1) + init.constant_(m.bias.data, 0.0) + + +def make_layer(block, n_layers): + layers = [] + for _ in range(n_layers): + layers.append(block()) + return nn.Sequential(*layers) + + +class ResidualDenseBlock_5C(nn.Module): + def __init__(self, nf=64, gc=32, bias=True): + super(ResidualDenseBlock_5C, self).__init__() + # gc: growth channel, i.e. intermediate channels + self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +class RRDB(nn.Module): + '''Residual in Residual Dense Block''' + + def __init__(self, nf, gc=32): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C(nf, gc) + self.RDB2 = ResidualDenseBlock_5C(nf, gc) + self.RDB3 = ResidualDenseBlock_5C(nf, gc) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x + + +class RRDBNet(nn.Module): + def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4): + super(RRDBNet, self).__init__() + RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) + self.sf = sf + print([in_nc, out_nc, nf, nb, gc, sf]) + + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.RRDB_trunk = make_layer(RRDB_block_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + #### upsampling + self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + if self.sf==4: + self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + fea = self.conv_first(x) + trunk = self.trunk_conv(self.RRDB_trunk(fea)) + fea = fea + trunk + + fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) + if self.sf==4: + fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.HRconv(fea))) + + return out
\ No newline at end of file diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index 8fbdea24..efd881eb 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -5,31 +5,31 @@ import traceback import cv2
import torch
-from modules import shared, devices
-from modules.paths import script_path
-import modules.shared
import modules.face_restoration
-from importlib import reload
+import modules.shared
+from modules import shared, devices, modelloader
+from modules.paths import script_path, models_path
-# codeformer people made a choice to include modified basicsr librry to their projectwhich makes
-# it utterly impossiblr to use it alongside with other libraries that also use basicsr, like GFPGAN.
+# codeformer people made a choice to include modified basicsr library to their project which makes
+# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
# I am making a choice to include some files from codeformer to work around this issue.
-
-pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+model_dir = "Codeformer"
+model_path = os.path.join(models_path, model_dir)
+model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
have_codeformer = False
codeformer = None
-def setup_codeformer():
+
+def setup_model(dirname):
+ global model_path
+ if not os.path.exists(model_path):
+ os.makedirs(model_path)
+
path = modules.paths.paths.get("CodeFormer", None)
if path is None:
return
-
- # both GFPGAN and CodeFormer use bascisr, one has it installed from pip the other uses its own
- #stored_sys_path = sys.path
- #sys.path = [path] + sys.path
-
try:
from torchvision.transforms.functional import normalize
from modules.codeformer.codeformer_arch import CodeFormer
@@ -44,18 +44,23 @@ def setup_codeformer(): def name(self):
return "CodeFormer"
- def __init__(self):
+ def __init__(self, dirname):
self.net = None
self.face_helper = None
+ self.cmd_dir = dirname
def create_models(self):
if self.net is not None and self.face_helper is not None:
self.net.to(devices.device_codeformer)
return self.net, self.face_helper
-
+ model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir)
+ if len(model_paths) != 0:
+ ckpt_path = model_paths[0]
+ else:
+ print("Unable to load codeformer model.")
+ return None, None
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
- ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True)
checkpoint = torch.load(ckpt_path)['params_ema']
net.load_state_dict(checkpoint)
net.eval()
@@ -74,6 +79,9 @@ def setup_codeformer(): original_resolution = np_image.shape[0:2]
self.create_models()
+ if self.net is None or self.face_helper is None:
+ return np_image
+
self.face_helper.clean_all()
self.face_helper.read_image(np_image)
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
@@ -114,7 +122,7 @@ def setup_codeformer(): have_codeformer = True
global codeformer
- codeformer = FaceRestorerCodeFormer()
+ codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer)
except Exception:
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 7f3baf31..ce841aa4 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,80 +1,124 @@ import os
-import sys
-import traceback
import numpy as np
import torch
from PIL import Image
+from basicsr.utils.download_util import load_file_from_url
import modules.esrgam_model_arch as arch
-from modules import shared
-from modules.shared import opts
+from modules import shared, modelloader, images
from modules.devices import has_mps
-import modules.images
-
+from modules.paths import models_path
+from modules.upscaler import Upscaler, UpscalerData
+from modules.shared import opts
-def load_model(filename):
- # this code is adapted from https://github.com/xinntao/ESRGAN
- pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
- crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
- if 'conv_first.weight' in pretrained_net:
- crt_model.load_state_dict(pretrained_net)
- return crt_model
+class UpscalerESRGAN(Upscaler):
+ def __init__(self, dirname):
+ self.name = "ESRGAN"
+ self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
+ self.model_name = "ESRGAN 4x"
+ self.scalers = []
+ self.user_path = dirname
+ self.model_path = os.path.join(models_path, self.name)
+ super().__init__()
+ model_paths = self.find_models(ext_filter=[".pt", ".pth"])
+ scalers = []
+ if len(model_paths) == 0:
+ scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
+ scalers.append(scaler_data)
+ for file in model_paths:
+ print(f"File: {file}")
+ if "http" in file:
+ name = self.model_name
+ else:
+ name = modelloader.friendly_name(file)
+
+ scaler_data = UpscalerData(name, file, self, 4)
+ print(f"ESRGAN: Adding scaler {name}")
+ self.scalers.append(scaler_data)
+
+ def do_upscale(self, img, selected_model):
+ model = self.load_model(selected_model)
+ if model is None:
+ return img
+ model.to(shared.device)
+ img = esrgan_upscale(model, img)
+ return img
- if 'model.0.weight' not in pretrained_net:
- is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
- if is_realesrgan:
- raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
+ def load_model(self, path: str):
+ if "http" in path:
+ filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
+ file_name="%s.pth" % self.model_name,
+ progress=True)
else:
- raise Exception("The file is not a ESRGAN model.")
+ filename = path
+ if not os.path.exists(filename) or filename is None:
+ print("Unable to load %s from %s" % (self.model_path, filename))
+ return None
+ # this code is adapted from https://github.com/xinntao/ESRGAN
+ pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
+ crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
+
+ if 'conv_first.weight' in pretrained_net:
+ crt_model.load_state_dict(pretrained_net)
+ return crt_model
+
+ if 'model.0.weight' not in pretrained_net:
+ is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net[
+ "params_ema"]
+ if is_realesrgan:
+ raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
+ else:
+ raise Exception("The file is not a ESRGAN model.")
+
+ crt_net = crt_model.state_dict()
+ load_net_clean = {}
+ for k, v in pretrained_net.items():
+ if k.startswith('module.'):
+ load_net_clean[k[7:]] = v
+ else:
+ load_net_clean[k] = v
+ pretrained_net = load_net_clean
+
+ tbd = []
+ for k, v in crt_net.items():
+ tbd.append(k)
+
+ # directly copy
+ for k, v in crt_net.items():
+ if k in pretrained_net and pretrained_net[k].size() == v.size():
+ crt_net[k] = pretrained_net[k]
+ tbd.remove(k)
+
+ crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
+ crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
+
+ for k in tbd.copy():
+ if 'RDB' in k:
+ ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
+ if '.weight' in k:
+ ori_k = ori_k.replace('.weight', '.0.weight')
+ elif '.bias' in k:
+ ori_k = ori_k.replace('.bias', '.0.bias')
+ crt_net[k] = pretrained_net[ori_k]
+ tbd.remove(k)
+
+ crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
+ crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
+ crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
+ crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
+ crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
+ crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
+ crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
+ crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
+ crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
+ crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
+
+ crt_model.load_state_dict(crt_net)
+ crt_model.eval()
+ return crt_model
- crt_net = crt_model.state_dict()
- load_net_clean = {}
- for k, v in pretrained_net.items():
- if k.startswith('module.'):
- load_net_clean[k[7:]] = v
- else:
- load_net_clean[k] = v
- pretrained_net = load_net_clean
-
- tbd = []
- for k, v in crt_net.items():
- tbd.append(k)
-
- # directly copy
- for k, v in crt_net.items():
- if k in pretrained_net and pretrained_net[k].size() == v.size():
- crt_net[k] = pretrained_net[k]
- tbd.remove(k)
-
- crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
- crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
-
- for k in tbd.copy():
- if 'RDB' in k:
- ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
- if '.weight' in k:
- ori_k = ori_k.replace('.weight', '.0.weight')
- elif '.bias' in k:
- ori_k = ori_k.replace('.bias', '.0.bias')
- crt_net[k] = pretrained_net[ori_k]
- tbd.remove(k)
-
- crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
- crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
- crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
- crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
- crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
- crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
- crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
- crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
- crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
- crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
-
- crt_model.load_state_dict(crt_net)
- crt_model.eval()
- return crt_model
def upscale_without_tiling(model, img):
img = np.array(img)
@@ -95,7 +139,7 @@ def esrgan_upscale(model, img): if opts.ESRGAN_tile == 0:
return upscale_without_tiling(model, img)
- grid = modules.images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
+ grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
newtiles = []
scale_factor = 1
@@ -110,32 +154,7 @@ def esrgan_upscale(model, img): newrow.append([x * scale_factor, w * scale_factor, output])
newtiles.append([y * scale_factor, h * scale_factor, newrow])
- newgrid = modules.images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
- output = modules.images.combine_grid(newgrid)
+ newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor,
+ grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
+ output = images.combine_grid(newgrid)
return output
-
-
-class UpscalerESRGAN(modules.images.Upscaler):
- def __init__(self, filename, title):
- self.name = title
- self.model = load_model(filename)
-
- def do_upscale(self, img):
- model = self.model.to(shared.device)
- img = esrgan_upscale(model, img)
- return img
-
-
-def load_models(dirname):
- for file in os.listdir(dirname):
- path = os.path.join(dirname, file)
- model_name, extension = os.path.splitext(file)
-
- if extension != '.pt' and extension != '.pth':
- continue
-
- try:
- modules.shared.sd_upscalers.append(UpscalerESRGAN(path, model_name))
- except Exception:
- print(f"Error loading ESRGAN model: {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
diff --git a/modules/extras.py b/modules/extras.py index c2543fcf..1d4e9fa8 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -40,6 +40,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v outputs = []
for image, image_name in zip(imageArr, imageNameArr):
+ if image is None:
+ return outputs, "Please select an input image.", ''
existing_pnginfo = image.info or {}
image = image.convert("RGB")
@@ -65,29 +67,28 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
image = res
- if upscaling_resize != 1.0:
- def upscale(image, scaler_index, resize):
- small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
- pixels = tuple(np.array(small).flatten().tolist())
- key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
+ def upscale(image, scaler_index, resize):
+ small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
+ pixels = tuple(np.array(small).flatten().tolist())
+ key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
- c = cached_images.get(key)
- if c is None:
- upscaler = shared.sd_upscalers[scaler_index]
- c = upscaler.upscale(image, image.width * resize, image.height * resize)
- cached_images[key] = c
+ c = cached_images.get(key)
+ if c is None:
+ upscaler = shared.sd_upscalers[scaler_index]
+ c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
+ cached_images[key] = c
- return c
+ return c
- info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
- res = upscale(image, extras_upscaler_1, upscaling_resize)
+ info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
+ res = upscale(image, extras_upscaler_1, upscaling_resize)
- if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
- res2 = upscale(image, extras_upscaler_2, upscaling_resize)
- info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
- res = Image.blend(res, res2, extras_upscaler_2_visibility)
+ if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
+ res2 = upscale(image, extras_upscaler_2, upscaling_resize)
+ info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
+ res = Image.blend(res, res2, extras_upscaler_2_visibility)
- image = res
+ image = res
while len(cached_images) > 2:
del cached_images[next(iter(cached_images.keys()))]
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 44c5dc6c..2bf8a1ee 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -1,39 +1,25 @@ import os
import sys
import traceback
-from glob import glob
-from modules import shared, devices
-from modules.shared import cmd_opts
-from modules.paths import script_path
-import modules.face_restoration
-
-
-def gfpgan_model_path():
- from modules.shared import cmd_opts
-
- filemask = 'GFPGAN*.pth'
-
- if cmd_opts.gfpgan_model is not None:
- return cmd_opts.gfpgan_model
-
- places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
-
- filename = None
- for place in places:
- filename = next(iter(glob(os.path.join(place, filemask))), None)
- if filename is not None:
- break
-
- return filename
+import facexlib
+import gfpgan
+import modules.face_restoration
+from modules import shared, devices, modelloader
+from modules.paths import models_path
+model_dir = "GFPGAN"
+user_path = None
+model_path = os.path.join(models_path, model_dir)
+model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
+have_gfpgan = False
loaded_gfpgan_model = None
-def gfpgan():
+def gfpgann():
global loaded_gfpgan_model
-
+ global model_path
if loaded_gfpgan_model is not None:
loaded_gfpgan_model.gfpgan.to(shared.device)
return loaded_gfpgan_model
@@ -41,7 +27,17 @@ def gfpgan(): if gfpgan_constructor is None:
return None
- model = gfpgan_constructor(model_path=gfpgan_model_path() or 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
+ models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
+ if len(models) == 1 and "http" in models[0]:
+ model_file = models[0]
+ elif len(models) != 0:
+ latest_file = max(models, key=os.path.getctime)
+ model_file = latest_file
+ else:
+ print("Unable to load gfpgan model!")
+ return None
+ model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2,
+ bg_upsampler=None)
model.gfpgan.to(shared.device)
loaded_gfpgan_model = model
@@ -49,10 +45,12 @@ def gfpgan(): def gfpgan_fix_faces(np_image):
- model = gfpgan()
-
+ model = gfpgann()
+ if model is None:
+ return np_image
np_image_bgr = np_image[:, :, ::-1]
- cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
+ cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False,
+ only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
if shared.opts.face_restoration_unload:
@@ -61,21 +59,41 @@ def gfpgan_fix_faces(np_image): return np_image
-have_gfpgan = False
gfpgan_constructor = None
-def setup_gfpgan():
- try:
- gfpgan_model_path()
- if os.path.exists(cmd_opts.gfpgan_dir):
- sys.path.append(os.path.abspath(cmd_opts.gfpgan_dir))
- from gfpgan import GFPGANer
+def setup_model(dirname):
+ global model_path
+ if not os.path.exists(model_path):
+ os.makedirs(model_path)
+ try:
+ from gfpgan import GFPGANer
+ from facexlib import detection, parsing
+ global user_path
global have_gfpgan
- have_gfpgan = True
-
global gfpgan_constructor
+
+ load_file_from_url_orig = gfpgan.utils.load_file_from_url
+ facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
+ facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
+
|