diff options
author | d8ahazard <d8ahazard@gmail.com> | 2022-09-29 22:46:23 +0000 |
---|---|---|
committer | d8ahazard <d8ahazard@gmail.com> | 2022-09-29 22:46:23 +0000 |
commit | 0dce0df1ee63b2f158805c1a1f1a3743cc4a104b (patch) | |
tree | dfcec33656d06835e71961b117b63e510cb9bff2 | |
parent | 31ad536c331df14dd785bfd2a1f93f91a8f7839e (diff) | |
download | stable-diffusion-webui-gfx803-0dce0df1ee63b2f158805c1a1f1a3743cc4a104b.tar.gz stable-diffusion-webui-gfx803-0dce0df1ee63b2f158805c1a1f1a3743cc4a104b.tar.bz2 stable-diffusion-webui-gfx803-0dce0df1ee63b2f158805c1a1f1a3743cc4a104b.zip |
Holy $hit.
Yep.
Fix gfpgan_model_arch requirement(s).
Add Upscaler base class, move from images.
Add a lot of methods to Upscaler.
Re-work all the child upscalers to be proper classes.
Add BSRGAN scaler.
Add ldsr_model_arch class, removing the dependency for another repo that just uses regular latent-diffusion stuff.
Add one universal method that will always find and load new upscaler models without having to add new "setup_model" calls. Still need to add command line params, but that could probably be automated.
Add a "self.scale" property to all Upscalers so the scalers themselves can do "things" in response to the requested upscaling size.
Ensure LDSR doesn't get stuck in a longer loop of "upscale/downscale/upscale" as we try to reach the target upscale size.
Add typehints for IDE sanity.
PEP-8 improvements.
Moar.
-rw-r--r-- | launch.py | 11 | ||||
-rw-r--r-- | modules/bsrgan_model.py | 79 | ||||
-rw-r--r-- | modules/bsrgan_model_arch.py | 103 | ||||
-rw-r--r-- | modules/esrgan_model.py | 227 | ||||
-rw-r--r-- | modules/extras.py | 35 | ||||
-rw-r--r-- | modules/gfpgan_model.py | 58 | ||||
-rw-r--r-- | modules/gfpgan_model_arch.py | 150 | ||||
-rw-r--r-- | modules/images.py | 84 | ||||
-rw-r--r-- | modules/ldsr_model.py | 103 | ||||
-rw-r--r-- | modules/ldsr_model_arch.py | 223 | ||||
-rw-r--r-- | modules/modelloader.py | 74 | ||||
-rw-r--r-- | modules/realesrgan_model.py | 209 | ||||
-rw-r--r-- | modules/sd_models.py | 3 | ||||
-rw-r--r-- | modules/sd_samplers.py | 4 | ||||
-rw-r--r-- | modules/shared.py | 18 | ||||
-rw-r--r-- | modules/swinir_model.py | 157 | ||||
-rw-r--r-- | modules/upscaler.py | 121 | ||||
-rw-r--r-- | webui.py | 9 |
18 files changed, 1018 insertions, 650 deletions
@@ -1,5 +1,5 @@ # this scripts installs necessary requirements and launches main program in webui.py
-
+import shutil
import subprocess
import os
import sys
@@ -22,7 +22,6 @@ stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "6 taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
-ldsr_commit_hash = os.environ.get('LDSR_COMMIT_HASH',"abf33e7002d59d9085081bce93ec798dcabd49af")
args = shlex.split(commandline_args)
@@ -122,9 +121,11 @@ git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-di git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
-# Using my repo until my changes are merged, as this makes interfacing with our version of SD-web a lot easier
-git_clone("https://github.com/Hafiidz/latent-diffusion", repo_dir('latent-diffusion'), "LDSR", ldsr_commit_hash)
-
+if os.path.isdir(repo_dir('latent-diffusion')):
+ try:
+ shutil.rmtree(repo_dir('latent-diffusion'))
+ except:
+ pass
if not is_installed("lpips"):
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py new file mode 100644 index 00000000..77141545 --- /dev/null +++ b/modules/bsrgan_model.py @@ -0,0 +1,79 @@ +import os.path +import sys +import traceback + +import PIL.Image +import numpy as np +import torch +from basicsr.utils.download_util import load_file_from_url + +import modules.upscaler +from modules import shared, modelloader +from modules.bsrgan_model_arch import RRDBNet +from modules.paths import models_path + + +class UpscalerBSRGAN(modules.upscaler.Upscaler): + def __init__(self, dirname): + self.name = "BSRGAN" + self.model_path = os.path.join(models_path, self.name) + self.model_name = "BSRGAN 4x" + self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth" + self.user_path = dirname + super().__init__() + model_paths = self.find_models(ext_filter=[".pt", ".pth"]) + scalers = [] + if len(model_paths) == 0: + scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4) + scalers.append(scaler_data) + for file in model_paths: + if "http" in file: + name = self.model_name + else: + name = modelloader.friendly_name(file) + try: + scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) + scalers.append(scaler_data) + except Exception: + print(f"Error loading BSRGAN model: {file}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + self.scalers = scalers + + def do_upscale(self, img: PIL.Image, selected_file): + torch.cuda.empty_cache() + model = self.load_model(selected_file) + if model is None: + return img + model.to(shared.device) + torch.cuda.empty_cache() + img = np.array(img) + img = img[:, :, ::-1] + img = np.moveaxis(img, 2, 0) / 255 + img = torch.from_numpy(img).float() + img = img.unsqueeze(0).to(shared.device) + with torch.no_grad(): + output = model(img) + output = output.squeeze().float().cpu().clamp_(0, 1).numpy() + output = 255. * np.moveaxis(output, 0, 2) + output = output.astype(np.uint8) + output = output[:, :, ::-1] + torch.cuda.empty_cache() + return PIL.Image.fromarray(output, 'RGB') + + def load_model(self, path: str): + if "http" in path: + filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, + progress=True) + else: + filename = path + if not os.path.exists(filename) or filename is None: + print("Unable to load %s from %s" % (self.model_dir, filename)) + return None + print("Loading %s from %s" % (self.model_dir, filename)) + model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2) # define network + model.load_state_dict(torch.load(filename), strict=True) + model.eval() + for k, v in model.named_parameters(): + v.requires_grad = False + return model + diff --git a/modules/bsrgan_model_arch.py b/modules/bsrgan_model_arch.py new file mode 100644 index 00000000..d72647db --- /dev/null +++ b/modules/bsrgan_model_arch.py @@ -0,0 +1,103 @@ +import functools +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init + + +def initialize_weights(net_l, scale=1): + if not isinstance(net_l, list): + net_l = [net_l] + for net in net_l: + for m in net.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale # for residual block + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, a=0, mode='fan_in') + m.weight.data *= scale + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + init.constant_(m.weight, 1) + init.constant_(m.bias.data, 0.0) + + +def make_layer(block, n_layers): + layers = [] + for _ in range(n_layers): + layers.append(block()) + return nn.Sequential(*layers) + + +class ResidualDenseBlock_5C(nn.Module): + def __init__(self, nf=64, gc=32, bias=True): + super(ResidualDenseBlock_5C, self).__init__() + # gc: growth channel, i.e. intermediate channels + self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) + self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) + self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) + self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) + self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + return x5 * 0.2 + x + + +class RRDB(nn.Module): + '''Residual in Residual Dense Block''' + + def __init__(self, nf, gc=32): + super(RRDB, self).__init__() + self.RDB1 = ResidualDenseBlock_5C(nf, gc) + self.RDB2 = ResidualDenseBlock_5C(nf, gc) + self.RDB3 = ResidualDenseBlock_5C(nf, gc) + + def forward(self, x): + out = self.RDB1(x) + out = self.RDB2(out) + out = self.RDB3(out) + return out * 0.2 + x + + +class RRDBNet(nn.Module): + def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4): + super(RRDBNet, self).__init__() + RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) + self.sf = sf + print([in_nc, out_nc, nf, nb, gc, sf]) + + self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) + self.RRDB_trunk = make_layer(RRDB_block_f, nb) + self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + #### upsampling + self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + if self.sf==4: + self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) + self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + fea = self.conv_first(x) + trunk = self.trunk_conv(self.RRDB_trunk(fea)) + fea = fea + trunk + + fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) + if self.sf==4: + fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.HRconv(fea))) + + return out
\ No newline at end of file diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 5e10c49c..ce841aa4 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,6 +1,4 @@ import os
-import sys
-import traceback
import numpy as np
import torch
@@ -8,93 +6,119 @@ from PIL import Image from basicsr.utils.download_util import load_file_from_url
import modules.esrgam_model_arch as arch
-import modules.images
-from modules import shared
-from modules import shared, modelloader
+from modules import shared, modelloader, images
from modules.devices import has_mps
from modules.paths import models_path
+from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
-model_dir = "ESRGAN"
-model_path = os.path.join(models_path, model_dir)
-model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
-model_name = "ESRGAN_x4"
-
-
-def load_model(path: str, name: str):
- global model_path
- global model_url
- global model_dir
- global model_name
- if "http" in path:
- filename = load_file_from_url(url=model_url, model_dir=model_path, file_name="%s.pth" % model_name, progress=True)
- else:
- filename = path
- if not os.path.exists(filename) or filename is None:
- print("Unable to load %s from %s" % (model_dir, filename))
- return None
- print("Loading %s from %s" % (model_dir, filename))
- # this code is adapted from https://github.com/xinntao/ESRGAN
- pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
- crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
-
- if 'conv_first.weight' in pretrained_net:
- crt_model.load_state_dict(pretrained_net)
- return crt_model
- if 'model.0.weight' not in pretrained_net:
- is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
- if is_realesrgan:
- raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
- else:
- raise Exception("The file is not a ESRGAN model.")
+class UpscalerESRGAN(Upscaler):
+ def __init__(self, dirname):
+ self.name = "ESRGAN"
+ self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
+ self.model_name = "ESRGAN 4x"
+ self.scalers = []
+ self.user_path = dirname
+ self.model_path = os.path.join(models_path, self.name)
+ super().__init__()
+ model_paths = self.find_models(ext_filter=[".pt", ".pth"])
+ scalers = []
+ if len(model_paths) == 0:
+ scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
+ scalers.append(scaler_data)
+ for file in model_paths:
+ print(f"File: {file}")
+ if "http" in file:
+ name = self.model_name
+ else:
+ name = modelloader.friendly_name(file)
+
+ scaler_data = UpscalerData(name, file, self, 4)
+ print(f"ESRGAN: Adding scaler {name}")
+ self.scalers.append(scaler_data)
+
+ def do_upscale(self, img, selected_model):
+ model = self.load_model(selected_model)
+ if model is None:
+ return img
+ model.to(shared.device)
+ img = esrgan_upscale(model, img)
+ return img
- crt_net = crt_model.state_dict()
- load_net_clean = {}
- for k, v in pretrained_net.items():
- if k.startswith('module.'):
- load_net_clean[k[7:]] = v
+ def load_model(self, path: str):
+ if "http" in path:
+ filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
+ file_name="%s.pth" % self.model_name,
+ progress=True)
else:
- load_net_clean[k] = v
- pretrained_net = load_net_clean
-
- tbd = []
- for k, v in crt_net.items():
- tbd.append(k)
-
- # directly copy
- for k, v in crt_net.items():
- if k in pretrained_net and pretrained_net[k].size() == v.size():
- crt_net[k] = pretrained_net[k]
- tbd.remove(k)
-
- crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
- crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
-
- for k in tbd.copy():
- if 'RDB' in k:
- ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
- if '.weight' in k:
- ori_k = ori_k.replace('.weight', '.0.weight')
- elif '.bias' in k:
- ori_k = ori_k.replace('.bias', '.0.bias')
- crt_net[k] = pretrained_net[ori_k]
- tbd.remove(k)
-
- crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
- crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
- crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
- crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
- crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
- crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
- crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
- crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
- crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
- crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
-
- crt_model.load_state_dict(crt_net)
- crt_model.eval()
- return crt_model
+ filename = path
+ if not os.path.exists(filename) or filename is None:
+ print("Unable to load %s from %s" % (self.model_path, filename))
+ return None
+ # this code is adapted from https://github.com/xinntao/ESRGAN
+ pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
+ crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
+
+ if 'conv_first.weight' in pretrained_net:
+ crt_model.load_state_dict(pretrained_net)
+ return crt_model
+
+ if 'model.0.weight' not in pretrained_net:
+ is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net[
+ "params_ema"]
+ if is_realesrgan:
+ raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
+ else:
+ raise Exception("The file is not a ESRGAN model.")
+
+ crt_net = crt_model.state_dict()
+ load_net_clean = {}
+ for k, v in pretrained_net.items():
+ if k.startswith('module.'):
+ load_net_clean[k[7:]] = v
+ else:
+ load_net_clean[k] = v
+ pretrained_net = load_net_clean
+
+ tbd = []
+ for k, v in crt_net.items():
+ tbd.append(k)
+
+ # directly copy
+ for k, v in crt_net.items():
+ if k in pretrained_net and pretrained_net[k].size() == v.size():
+ crt_net[k] = pretrained_net[k]
+ tbd.remove(k)
+
+ crt_net['conv_first.weight'] = pretrained_net['model.0.weight']
+ crt_net['conv_first.bias'] = pretrained_net['model.0.bias']
+
+ for k in tbd.copy():
+ if 'RDB' in k:
+ ori_k = k.replace('RRDB_trunk.', 'model.1.sub.')
+ if '.weight' in k:
+ ori_k = ori_k.replace('.weight', '.0.weight')
+ elif '.bias' in k:
+ ori_k = ori_k.replace('.bias', '.0.bias')
+ crt_net[k] = pretrained_net[ori_k]
+ tbd.remove(k)
+
+ crt_net['trunk_conv.weight'] = pretrained_net['model.1.sub.23.weight']
+ crt_net['trunk_conv.bias'] = pretrained_net['model.1.sub.23.bias']
+ crt_net['upconv1.weight'] = pretrained_net['model.3.weight']
+ crt_net['upconv1.bias'] = pretrained_net['model.3.bias']
+ crt_net['upconv2.weight'] = pretrained_net['model.6.weight']
+ crt_net['upconv2.bias'] = pretrained_net['model.6.bias']
+ crt_net['HRconv.weight'] = pretrained_net['model.8.weight']
+ crt_net['HRconv.bias'] = pretrained_net['model.8.bias']
+ crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
+ crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
+
+ crt_model.load_state_dict(crt_net)
+ crt_model.eval()
+ return crt_model
+
def upscale_without_tiling(model, img):
img = np.array(img)
@@ -115,7 +139,7 @@ def esrgan_upscale(model, img): if opts.ESRGAN_tile == 0:
return upscale_without_tiling(model, img)
- grid = modules.images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
+ grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
newtiles = []
scale_factor = 1
@@ -130,38 +154,7 @@ def esrgan_upscale(model, img): newrow.append([x * scale_factor, w * scale_factor, output])
newtiles.append([y * scale_factor, h * scale_factor, newrow])
- newgrid = modules.images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
- output = modules.images.combine_grid(newgrid)
+ newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor,
+ grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
+ output = images.combine_grid(newgrid)
return output
-
-
-class UpscalerESRGAN(modules.images.Upscaler):
- def __init__(self, filename, title):
- self.name = title
- self.filename = filename
-
- def do_upscale(self, img):
- model = load_model(self.filename, self.name)
- if model is None:
- return img
- model.to(shared.device)
- img = esrgan_upscale(model, img)
- return img
-
-
-def setup_model(dirname):
- global model_path
- global model_name
- if not os.path.exists(model_path):
- os.makedirs(model_path)
-
- model_paths = modelloader.load_models(model_path, command_path=dirname, ext_filter=[".pt", ".pth"])
- if len(model_paths) == 0:
- modules.shared.sd_upscalers.append(UpscalerESRGAN(model_url, model_name))
- for file in model_paths:
- name = modelloader.friendly_name(file)
- try:
- modules.shared.sd_upscalers.append(UpscalerESRGAN(file, name))
- except Exception:
- print(f"Error loading ESRGAN model: {file}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
diff --git a/modules/extras.py b/modules/extras.py index af6e631f..d7d0fa54 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -66,29 +66,28 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
image = res
- if upscaling_resize != 1.0:
- def upscale(image, scaler_index, resize):
- small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
- pixels = tuple(np.array(small).flatten().tolist())
- key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
+ def upscale(image, scaler_index, resize):
+ small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
+ pixels = tuple(np.array(small).flatten().tolist())
+ key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
- c = cached_images.get(key)
- if c is None:
- upscaler = shared.sd_upscalers[scaler_index]
- c = upscaler.upscale(image, image.width * resize, image.height * resize)
- cached_images[key] = c
+ c = cached_images.get(key)
+ if c is None:
+ upscaler = shared.sd_upscalers[scaler_index]
+ c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
+ cached_images[key] = c
- return c
+ return c
- info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
- res = upscale(image, extras_upscaler_1, upscaling_resize)
+ info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
+ res = upscale(image, extras_upscaler_1, upscaling_resize)
- if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
- res2 = upscale(image, extras_upscaler_2, upscaling_resize)
- info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
- res = Image.blend(res, res2, extras_upscaler_2_visibility)
+ if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
+ res2 = upscale(image, extras_upscaler_2, upscaling_resize)
+ info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
+ res = Image.blend(res, res2, extras_upscaler_2_visibility)
- image = res
+ image = res
while len(cached_images) > 2:
del cached_images[next(iter(cached_images.keys()))]
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index ffb6960d..2bf8a1ee 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -1,24 +1,23 @@ import os
import sys
import traceback
-from glob import glob
-from modules import shared, devices
-from modules.shared import cmd_opts
-from modules.paths import script_path
+import facexlib
+import gfpgan
+
import modules.face_restoration
from modules import shared, devices, modelloader
from modules.paths import models_path
model_dir = "GFPGAN"
-cmd_dir = None
+user_path = None
model_path = os.path.join(models_path, model_dir)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
-
+have_gfpgan = False
loaded_gfpgan_model = None
-def gfpgan():
+def gfpgann():
global loaded_gfpgan_model
global model_path
if loaded_gfpgan_model is not None:
@@ -28,14 +27,16 @@ def gfpgan(): if gfpgan_constructor is None:
return None
- models = modelloader.load_models(model_path, model_url, cmd_dir)
- if len(models) != 0:
+ models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
+ if len(models) == 1 and "http" in models[0]:
+ model_file = models[0]
+ elif len(models) != 0:
latest_file = max(models, key=os.path.getctime)
model_file = latest_file
else:
print("Unable to load gfpgan model!")
return None
- model = gfpgan_constructor(model_path=model_file, model_dir=model_path, upscale=1, arch='clean', channel_multiplier=2,
+ model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2,
bg_upsampler=None)
model.gfpgan.to(shared.device)
loaded_gfpgan_model = model
@@ -44,11 +45,12 @@ def gfpgan(): def gfpgan_fix_faces(np_image):
- model = gfpgan()
+ model = gfpgann()
if model is None:
return np_image
np_image_bgr = np_image[:, :, ::-1]
- cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
+ cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False,
+ only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
if shared.opts.face_restoration_unload:
@@ -57,7 +59,6 @@ def gfpgan_fix_faces(np_image): return np_image
-have_gfpgan = False
gfpgan_constructor = None
@@ -67,14 +68,33 @@ def setup_model(dirname): os.makedirs(model_path)
try:
- from modules.gfpgan_model_arch import GFPGANerr
- global cmd_dir
+ from gfpgan import GFPGANer
+ from facexlib import detection, parsing
+ global user_path
global have_gfpgan
global gfpgan_constructor
- cmd_dir = dirname
+ load_file_from_url_orig = gfpgan.utils.load_file_from_url
+ facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
+ facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
+
+ def my_load_file_from_url(**kwargs):
+ print("Setting model_dir to " + model_path)
+ return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
+
+ def facex_load_file_from_url(**kwargs):
+ return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
+
+ def facex_load_file_from_url2(**kwargs):
+ return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
+
+ gfpgan.utils.load_file_from_url = my_load_file_from_url
+ facexlib.detection.load_file_from_url = facex_load_file_from_url
+ facexlib.parsing.load_file_from_url = facex_load_file_from_url2
+ user_path = dirname
+ print("Have gfpgan should be true?")
have_gfpgan = True
- gfpgan_constructor = GFPGANerr
+ gfpgan_constructor = GFPGANer
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
def name(self):
@@ -82,7 +102,9 @@ def setup_model(dirname): def restore(self, np_image):
np_image_bgr = np_image[:, :, ::-1]
- cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
+ cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False,
+ only_center_face=False,
+ paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
return np_image
diff --git a/modules/gfpgan_model_arch.py b/modules/gfpgan_model_arch.py deleted file mode 100644 index d81cea96..00000000 --- a/modules/gfpgan_model_arch.py +++ /dev/null @@ -1,150 +0,0 @@ -# GFPGAN likes to download stuff "wherever", and we're trying to fix that, so this is a copy of the original... - -import cv2 -import os -import torch -from basicsr.utils import img2tensor, tensor2img -from basicsr.utils.download_util import load_file_from_url -from facexlib.utils.face_restoration_helper import FaceRestoreHelper -from torchvision.transforms.functional import normalize - -from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear -from gfpgan.archs.gfpganv1_arch import GFPGANv1 -from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean - -ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - - -class GFPGANerr(): - """Helper for restoration with GFPGAN. - - It will detect and crop faces, and then resize the faces to 512x512. - GFPGAN is used to restored the resized faces. - The background is upsampled with the bg_upsampler. - Finally, the faces will be pasted back to the upsample background image. - - Args: - model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically). - upscale (float): The upscale of the final output. Default: 2. - arch (str): The GFPGAN architecture. Option: clean | original. Default: clean. - channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. - bg_upsampler (nn.Module): The upsampler for the background. Default: None. - """ - - def __init__(self, model_path, model_dir, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None): - self.upscale = upscale - self.bg_upsampler = bg_upsampler - - # initialize model - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device - # initialize the GFP-GAN - if arch == 'clean': - self.gfpgan = GFPGANv1Clean( - out_size=512, - num_style_feat=512, - channel_multiplier=channel_multiplier, - decoder_load_path=None, - fix_decoder=False, - num_mlp=8, - input_is_latent=True, - different_w=True, - narrow=1, - sft_half=True) - elif arch == 'bilinear': - self.gfpgan = GFPGANBilinear( - out_size=512, - num_style_feat=512, - channel_multiplier=channel_multiplier, - decoder_load_path=None, - fix_decoder=False, - num_mlp=8, - input_is_latent=True, - different_w=True, - narrow=1, - sft_half=True) - elif arch == 'original': - self.gfpgan = GFPGANv1( - out_size=512, - num_style_feat=512, - channel_multiplier=channel_multiplier, - decoder_load_path=None, - fix_decoder=True, - num_mlp=8, - input_is_latent=True, - different_w=True, - narrow=1, - sft_half=True) - elif arch == 'RestoreFormer': - from gfpgan.archs.restoreformer_arch import RestoreFormer - self.gfpgan = RestoreFormer() - # initialize face helper - self.face_helper = FaceRestoreHelper( - upscale, - face_size=512, - crop_ratio=(1, 1), - det_model='retinaface_resnet50', - save_ext='png', - use_parse=True, - device=self.device, - model_rootpath=model_dir) - - if model_path.startswith('https://'): - model_path = load_file_from_url( - url=model_path, model_dir=model_dir, progress=True, file_name=None) - loadnet = torch.load(model_path) - if 'params_ema' in loadnet: - keyname = 'params_ema' - else: - keyname = 'params' - self.gfpgan.load_state_dict(loadnet[keyname], strict=True) - self.gfpgan.eval() - self.gfpgan = self.gfpgan.to(self.device) - - @torch.no_grad() - def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5): - self.face_helper.clean_all() - - if has_aligned: # the inputs are already aligned - img = cv2.resize(img, (512, 512)) - self.face_helper.cropped_faces = [img] - else: - self.face_helper.read_image(img) - # get face landmarks for each face - self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5) - # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels - # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations. - # align and warp each face - self.face_helper.align_warp_face() - - # face restoration - for cropped_face in self.face_helper.cropped_faces: - # prepare data - cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) - normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) - cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device) - - try: - output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0] - # convert to image - restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)) - except RuntimeError as error: - print(f'\tFailed inference for GFPGAN: {error}.') - restored_face = cropped_face - - restored_face = restored_face.astype('uint8') - self.face_helper.add_restored_face(restored_face) - - if not has_aligned and paste_back: - # upsample the background - if self.bg_upsampler is not None: - # Now only support RealESRGAN for upsampling background - bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0] - else: - bg_img = None - - self.face_helper.get_inverse_affine(None) - # paste each restored face to the input image - restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img) - return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img - else: - return self.face_helper.cropped_faces, self.face_helper.restored_faces, None diff --git a/modules/images.py b/modules/images.py index 9458bf8d..a6538dbe 100644 --- a/modules/images.py +++ b/modules/images.py @@ -11,7 +11,6 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin from fonts.ttf import Roboto
import string
-import modules.shared
from modules import sd_samplers, shared
from modules.shared import opts, cmd_opts
@@ -52,8 +51,8 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64): cols = math.ceil((w - overlap) / non_overlap_width)
rows = math.ceil((h - overlap) / non_overlap_height)
- dx = (w - tile_w) / (cols-1) if cols > 1 else 0
- dy = (h - tile_h) / (rows-1) if rows > 1 else 0
+ dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
+ dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
grid = Grid([], tile_w, tile_h, w, h, overlap)
for row in range(rows):
@@ -67,7 +66,7 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64): for col in range(cols):
x = int(col * dx)
- if x+tile_w >= w:
+ if x + tile_w >= w:
x = w - tile_w
tile = image.crop((x, y, x + tile_w, y + tile_h))
@@ -85,8 +84,10 @@ def combine_grid(grid): r = r.astype(np.uint8)
return Image.fromarray(r, 'L')
- mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
- mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
+ mask_w = make_mask_image(
+ np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
+ mask_h = make_mask_image(
+ np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
for y, h, row in grid.tiles:
@@ -129,10 +130,12 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts): def draw_texts(drawing, draw_x, draw_y, lines):
for i, line in enumerate(lines):
- drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
+ drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt,
+ fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
if not line.is_active:
- drawing.line((draw_x - line.size[0]//2, draw_y + line.size[1]//2, draw_x + line.size[0]//2, draw_y + line.size[1]//2), fill=color_inactive, width=4)
+ drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2,
+ draw_y + line.size[1] // 2), fill=color_inactive, width=4)
draw_y += line.size[1] + line_spacing
@@ -171,7 +174,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts): line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
- ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
+ ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
+ ver_texts]
pad_top = max(hor_text_heights) + line_spacing * 2
@@ -202,8 +206,10 @@ def draw_prompt_matrix(im, width, height, all_prompts): prompts_horiz = prompts[:boundary]
prompts_vert = prompts[boundary:]
- hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
- ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
+ hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in
+ range(1 << len(prompts_horiz))]
+ ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in
+ range(1 << len(prompts_vert))]
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
@@ -214,7 +220,8 @@ def resize_image(resize_mode, im, width, height): return im.resize((w, h), resample=LANCZOS)
upscaler = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img][0]
- return upscaler.upscale(im, w, h)
+ scale = w / im.width
+ return upscaler.scaler.upscale(im, scale)
if resize_mode == 0:
res = resize(im, width, height)
@@ -244,11 +251,13 @@ def resize_image(resize_mode, im, width, height): if ratio < src_ratio:
fill_height = height // 2 - src_h // 2
res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
- res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
+ res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
+ box=(0, fill_height + src_h))
elif ratio > src_ratio:
fill_width = width // 2 - src_w // 2
res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
- res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
+ res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
+ box=(fill_width + src_w, 0))
return res
@@ -256,7 +265,7 @@ def resize_image(resize_mode, im, width, height): invalid_filename_chars = '<>:"/\\|?*\n'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
-re_nonletters = re.compile(r'[\s'+string.punctuation+']+')
+re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
max_filename_part_length = 128
@@ -283,7 +292,8 @@ def apply_filename_pattern(x, p, seed, prompt): words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
if len(words) == 0:
words = ["empty"]
- x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
+ x = x.replace("[prompt_words]",
+ sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))
if p is not None:
x = x.replace("[steps]", str(p.steps))
@@ -291,7 +301,8 @@ def apply_filename_pattern(x, p, seed, prompt): x = x.replace("[width]", str(p.width))
x = x.replace("[height]", str(p.height))
x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False))
- x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
+ x = x.replace("[sampler]",
+ sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
x = x.replace("[date]", datetime.date.today().isoformat())
@@ -303,6 +314,7 @@ def apply_filename_pattern(x, p, seed, prompt): return x
+
def get_next_sequence_number(path, basename):
"""
Determines and returns the next sequence number to use when saving an image in the specified directory.
@@ -316,7 +328,8 @@ def get_next_sequence_number(path, basename): prefix_length = len(basename)
for p in os.listdir(path):
if p.startswith(basename):
- l = os.path.splitext(p[prefix_length:])[0].split('-') #splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
+ l = os.path.splitext(p[prefix_length:])[0].split(
+ '-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
try:
result = max(int(l[0]), result)
except ValueError:
@@ -324,7 +337,10 @@ def get_next_sequence_number(path, basename): return result + 1
-def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""):
+
+def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False,
+ no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None,
+ forced_filename=None, suffix=""):
if short_filename or prompt is None or seed is None:
file_decoration = ""
elif opts.save_to_dirs:
@@ -361,7 +377,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i fullfn = "a.png"
fullfn_without_extension = "a"
for i in range(500):
- fn = f"{basecount+i:05}" if basename == '' else f"{basename}-{basecount+i:04}"
+ fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
if not os.path.exists(fullfn):
@@ -403,31 +419,3 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i file.write(info + "\n")
-class Upscaler:
- name = "Lanczos"
-
- def do_upscale(self, img):
- return img
-
- def upscale(self, img, w, h):
- for i in range(3):
- if img.width >= w and img.height >= h:
- break
-
- img = self.do_upscale(img)
-
- if img.width != w or img.height != h:
- img = img.resize((int(w), int(h)), resample=LANCZOS)
-
- return img
-
-
-class UpscalerNone(Upscaler):
- name = "None"
-
- def upscale(self, img, w, h):
- return img
-
-
-modules.shared.sd_upscalers.append(UpscalerNone())
-modules.shared.sd_upscalers.append(Upscaler())
diff --git a/modules/ldsr_model.py b/modules/ldsr_model.py index 4f9b1657..969d1a0d 100644 --- a/modules/ldsr_model.py +++ b/modules/ldsr_model.py @@ -1,74 +1,45 @@ import os import sys import traceback -from collections import namedtuple -from modules import shared, images, modelloader, paths -from modules.paths import models_path - -model_dir = "LDSR" -model_path = os.path.join(models_path, model_dir) -cmd_path = None -model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" -yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" +from basicsr.utils.download_util import load_file_from_url -LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"]) - -ldsr_models = [] -have_ldsr = False -LDSR_obj = None +from modules.upscaler import Upscaler, UpscalerData +from modules.ldsr_model_arch import LDSR +from modules import shared +from modules.paths import models_path -class UpscalerLDSR(images.Upscaler): - def __init__(self, steps): - self.steps = steps +class UpscalerLDSR(Upscaler): + def __init__(self, user_path): self.name = "LDSR" - - def do_upscale(self, img): - return upscale_with_ldsr(img) - - -def setup_model(dirname): - global cmd_path - global model_path - if not os.path.exists(model_path): - os.makedirs(model_path) - cmd_path = dirname - shared.sd_upscalers.append(UpscalerLDSR(100)) - - -def prepare_ldsr(): - path = paths.paths.get("LDSR", None) - if path is None: - return - global have_ldsr - global LDSR_obj - try: - from LDSR import LDSR - model_files = modelloader.load_models(model_path, model_url, cmd_path, dl_name="model.ckpt", ext_filter=[".ckpt"]) - yaml_files = modelloader.load_models(model_path, yaml_url, cmd_path, dl_name="project.yaml", ext_filter=[".yaml"]) - if len(model_files) != 0 and len(yaml_files) != 0: - model_file = model_files[0] - yaml_file = yaml_files[0] - have_ldsr = True - LDSR_obj = LDSR(model_file, yaml_file) - else: - return - - except Exception: - print("Error importing LDSR:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - have_ldsr = False - - -def upscale_with_ldsr(image): - prepare_ldsr() - if not have_ldsr or LDSR_obj is None: - return image - - ddim_steps = shared.opts.ldsr_steps - pre_scale = shared.opts.ldsr_pre_down - post_scale = shared.opts.ldsr_post_down - - image = LDSR_obj.super_resolution(image, ddim_steps, pre_scale, post_scale) - return image + self.model_path = os.path.join(models_path, self.name) + self.user_path = user_path + self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" + self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" + super().__init__() + scaler_data = UpscalerData("LDSR", None, self) + self.scalers = [scaler_data] + + def load_model(self, path: str): + model = load_file_from_url(url=self.model_url, model_dir=self.model_path, + file_name="model.pth", progress=True) + yaml = load_file_from_url(url=self.model_url, model_dir=self.model_path, + file_name="project.yaml", progress=True) + + try: + return LDSR(model, yaml) + + except Exception: + print("Error importing LDSR:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + return None + + def do_upscale(self, img, path): + ldsr = self.load_model(path) + if ldsr is None: + print("NO LDSR!") + return img + ddim_steps = shared.opts.ldsr_steps + pre_scale = shared.opts.ldsr_pre_down + return ldsr.super_resolution(img, ddim_steps, self.scale) diff --git a/modules/ldsr_model_arch.py b/modules/ldsr_model_arch.py new file mode 100644 index 00000000..8fe87c6a --- /dev/null +++ b/modules/ldsr_model_arch.py @@ -0,0 +1,223 @@ +import gc +import time +import warnings + +import numpy as np +import torch +import torchvision +from PIL import Image +from einops import rearrange, repeat +from omegaconf import OmegaConf + +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.util import instantiate_from_config, ismap + +warnings.filterwarnings("ignore", category=UserWarning) + + +# Create LDSR Class +class LDSR: + def load_model_from_config(self, half_attention): + print(f"Loading model from {self.modelPath}") + pl_sd = torch.load(self.modelPath, map_location="cpu") + sd = pl_sd["state_dict"] + config = OmegaConf.load(self.yamlPath) + model = instantiate_from_config(config.model) + model.load_state_dict(sd, strict=False) + model.cuda() + if half_attention: + model = model.half() + + model.eval() + return {"model": model} + + def __init__(self, model_path, yaml_path): + self.modelPath = model_path + self.yamlPath = yaml_path + + @staticmethod + def run(model, selected_path, custom_steps, eta): + example = get_cond(selected_path) + + n_runs = 1 + guider = None + ckwargs = None + ddim_use_x0_pred = False + temperature = 1. + eta = eta + custom_shape = None + + height, width = example["image"].shape[1:3] + split_input = height >= 128 and width >= 128 + + if split_input: + ks = 128 + stride = 64 + vqf = 4 # + model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride), + "vqf": vqf, + "patch_distributed_vq": True, + "tie_braker": False, + "clip_max_weight": 0.5, + "clip_min_weight": 0.01, + "clip_max_tie_weight": 0.5, + "clip_min_tie_weight": 0.01} + else: + if hasattr(model, "split_input_params"): + delattr(model, "split_input_params") + + x_t = None + logs = None + for n in range(n_runs): + if custom_shape is not None: + x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device) + x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0]) + + logs = make_convolutional_sample(example, model, + custom_steps=custom_steps, + eta=eta, quantize_x0=False, + custom_shape=custom_shape, + temperature=temperature, noise_dropout=0., + corrector=guider, corrector_kwargs=ckwargs, x_T=x_t, + ddim_use_x0_pred=ddim_use_x0_pred + ) + return logs + + def super_resolution(self, image, steps=100, target_scale=2, half_attention=False): + model = self.load_model_from_config(half_attention) + + # Run settings + diffusion_steps = int(steps) + eta = 1.0 + + down_sample_method = 'Lanczos' + + gc.collect() + torch.cuda.empty_cache() + + im_og = image + width_og, height_og = im_og.size + # If we can adjust the max upscale size, then the 4 below should be our variable + print("Foo") + down_sample_rate = target_scale / 4 + print(f"Downsample rate is {down_sample_rate}") + width_downsampled_pre = width_og * down_sample_rate + height_downsampled_pre = height_og * down_sample_method + + if down_sample_rate != 1: + print( + f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]') + im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) + else: + print(f"Down sample rate is 1 from {target_scale} / 4") + logs = self.run(model["model"], im_og, diffusion_steps, eta) + + sample = logs["sample"] + sample = sample.detach().cpu() + sample = torch.clamp(sample, -1., 1.) + sample = (sample + 1.) / 2. * 255 + sample = sample.numpy().astype(np.uint8) + sample = np.transpose(sample, (0, 2, 3, 1)) + a = Image.fromarray(sample[0]) + + del model + gc.collect() + torch.cuda.empty_cache() + print(f'Processing finished!') + return a + + +def get_cond(selected_path): + example = dict() + up_f = 4 + c = selected_path.convert('RGB') + c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0) + c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], + antialias=True) + c_up = rearrange(c_up, '1 c h w -> 1 h w c') + c = rearrange(c, '1 c h w -> 1 h w c') + c = 2. * c - 1. + + c = c.to(torch.device("cuda")) + example["LR_image"] = c + example["image"] = c_up + + return example + + +@torch.no_grad() +def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None, + mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None, + corrector_kwargs=None, x_t=None + ): + ddim = DDIMSampler(model) + bs = shape[0] + shape = shape[1:] + print(f"Sampling with eta = {eta}; steps: {steps}") + samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback, + normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta, + mask=mask, x0=x0, temperature=temperature, verbose=False, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, x_t=x_t) + + return samples, intermediates + + +@torch.no_grad() +def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None, + corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False): + log = dict() + + z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=not (hasattr(model, 'split_input_params') + and model.cond_stage_key == 'coordinates_bbox'), + return_original_cond=True) + + if custom_shape is not None: + z = torch.randn(custom_shape) + print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}") + + z0 = None + + log["input"] = x + log["reconstruction"] = xrec + + if ismap(xc): + log["original_conditioning"] = model.to_rgb(xc) + if hasattr(model, 'cond_stage_key'): + log[model.cond_stage_key] = model.to_rgb(xc) + + else: + log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x) + if model.cond_stage_model: + log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x) + if model.cond_stage_key == 'class_label': + log[model.cond_stage_key] = xc[model.cond_stage_key] + + with model.ema_scope("Plotting"): + t0 = time.time() + + sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape, + eta=eta, + quantize_x0=quantize_x0, mask=None, x0=z0, + temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs, + x_t=x_T) + t1 = time.time() + + if ddim_use_x0_pred: + sample = intermediates['pred_x0'][-1] + + x_sample = model.decode_first_stage(sample) + + try: + x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True) + log["sample_noquant"] = x_sample_noquant + log["sample_diff"] = torch.abs(x_sample_noquant - x_sample) + except: + pass + + log["sample"] = x_sample + log["time"] = t1 - t0 + + return log diff --git a/modules/modelloader.py b/modules/modelloader.py index 3bd1de4d..6de65c69 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,34 +1,36 @@ import os import shutil +import importlib from urllib.parse import urlparse from basicsr.utils.download_util import load_file_from_url +from modules import shared +from modules.upscaler import Upscaler from modules.paths import script_path, models_path -def load_models(model_path: str, model_url: str = None, command_path: str = None, dl_name: str = None, existing=None, - ext_filter=None) -> list: +def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list: """ A one-and done loader to try finding the desired models in specified directories. - @param dl_name: The file name to use for downloading a model. If not specified, it will be used from the URL. - @param model_url: If specified, attempt to download model from the given URL. + @param download_name: Specify to download from model_url immediately. + @param model_url: If no other models are found, this will be downloaded on upscale. @param model_path: The location to store/find models in. @param command_path: A command-line argument to search for models in first. - @param existing: An array of existing model paths. @param ext_filter: An optional list of filename extensions to filter by @return: A list of paths containing the desired model(s) """ + output = [] + if ext_filter is None: ext_filter = [] - if existing is None: - existing = [] try: places = [] if command_path is not None and command_path != model_path: pretrained_path = os.path.join(command_path, 'experiments/pretrained_models') if os.path.exists(pretrained_path): + print(f"Appending path: {pretrained_path}") places.append(pretrained_path) elif os.path.exists(command_path): places.append(command_path) @@ -36,26 +38,24 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None for place in places: if os.path.exists(place): for file in os.listdir(place): - if os.path.isdir(file): + full_path = os.path.join(place, file) + if os.path.isdir(full_path): continue if len(ext_filter) != 0: model_name, extension = os.path.splitext(file) if extension not in ext_filter: continue - if file not in existing: - path = os.path.join(place, file) - existing.append(path) - if model_url is not None and len(existing) == 0: - if dl_name is not None: - model_file = load_file_from_url(url=model_url, model_dir=model_path, file_name=dl_name, progress=True) + if file not in output: + output.append(full_path) + if model_url is not None and len(output) == 0: + if download_name is not None: + dl = load_file_from_url(model_url, model_path, True, download_name) + output.append(dl) else: - model_file = load_file_from_url(url=model_url, model_dir=model_path, progress=True) - - if os.path.exists(model_file) and os.path.isfile(model_file) and model_file not in existing: - existing.append(model_file) + output.append(model_url) except: pass - return existing + return output def friendly_name(file: str): @@ -110,4 +110,38 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None): print(f"Removing empty folder: {src_path}") shutil.rmtree(src_path, True) except: - pass
\ No newline at end of file + pass + + +def load_upscalers(): + datas = [] + for cls in Upscaler.__subclasses__(): + name = cls.__name__ + module_name = cls.__module__ + print(f"Class: {name} and {module_name}") + module = importlib.import_module(module_name) + class_ = getattr(module, name) + cmd_name = f"{name.lower().replace('upscaler', '')}-models-path" + print(f"CMD Name: {cmd_name}") + opt_string = None + try: + opt_string = shared.opts.__getattr__(cmd_name) + except: + pass + scaler = class_(opt_string) + for child in scaler.scalers: + print(f"Appending {child.name}") + datas.append(child) + + shared.sd_upscalers = datas + + # for scaler in subclasses: + # print(f"Found scaler: {type(scaler).__name__}") + # try: + # scaler = scaler() + # for child in scaler.scalers: + # print(f"Appending {child.name}") + # datas.append[child] + # except: + # pass + # shared.sd_upscalers = datas diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 458bf678..0a2eb896 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,64 +1,135 @@ import os
import sys
import traceback
-from collections import namedtuple
import numpy as np
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
-import modules.images
+from modules.upscaler import Upscaler, UpscalerData
from modules.paths import models_path
from modules.shared import cmd_opts, opts
-model_dir = "RealESRGAN"
-model_path = os.path.join(models_path, model_dir)
-cmd_dir = None
-RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
-realesrgan_models = []
-have_realesrgan = False
+class UpscalerRealESRGAN(Upscaler):
+ def __init__(self, path):
+ self.name = "RealESRGAN"
+ self.model_path = os.path.join(models_path, self.name)
+ self.user_path = path
+ super().__init__()
+ try:
+ from basicsr.archs.rrdbnet_arch import RRDBNet
+ from realesrgan import RealESRGANer
+ from realesrgan.archs.srvgg_arch import SRVGGNetCompact
+ self.enable = True
+ self.scalers = []
+ scalers = self.load_models(path)
+ for scaler in scalers:
+ if scaler.name in opts.realesrgan_enabled_models:
+ self.scalers.append(scaler)
+
+ except Exception:
+ print("Error importing Real-ESRGAN:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ self.enable = False
+ self.scalers = []
+
+ def do_upscale(self, img, path):
+ if not self.enable:
+ return img
+
+ info = self.load_model(path)
+ if not os.path.exists(info.data_path):
+ print("Unable to load RealESRGAN model: %s" % info.name)
+ return img
+
+ upsampler = RealESRGANer(
+ scale=info.scale,
+ model_path=info.data_path,
+ model=info.model(),
+ half=not cmd_opts.no_half,
+ tile=opts.ESRGAN_tile,
+ tile_pad=opts.ESRGAN_tile_overlap,
+ )
+
+ upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
+
+ image = Image.fromarray(upsampled)
+ return image
+
+ def load_model(self, path):
+ try:
+ info = None
+ for scaler in self.scalers:
+ if scaler.data_path == path:
+ info = scaler
+
+ if info is None:
+ print(f"Unable to find model info: {path}")
+ return None
+
+ model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
+ info.data_path = model_file
+ return info
+ except Exception as e:
+ print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ return None
-def get_realesrgan_models():
+ def load_models(self, _):
+ return get_realesrgan_models(self)
+
+
+def get_realesrgan_models(scaler):
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
models = [
- RealesrganModelInfo(
- name="Real-ESRGAN General x4x3",
- location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
- netscale=4,
- model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
+ UpscalerData(
+ name="R-ESRGAN General 4xV3",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3"
+ ".pth",
+ scale=4,
+ upscaler=scaler,
+ model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4,
+ act_type='prelu')
),
- RealesrganModelInfo(
- name="Real-ESRGAN General WDN x4x3",
- location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
- netscale=4,
- model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
+ UpscalerData(
+ name="R-ESRGAN General WDN 4xV3",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
+ scale=4,
+ upscaler=scaler,
+ model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4,
+ act_type='prelu')
),
- RealesrganModelInfo(
- name="Real-ESRGAN AnimeVideo",
- location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
- netscale=4,
- model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
+ UpscalerData(
+ name="R-ESRGAN AnimeVideo",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
+ scale=4,
+ upscaler=scaler,
+ model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4,
+ act_type='prelu')
),
- RealesrganModelInfo(
- name="Real-ESRGAN 4x plus",
- location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
- netscale=4,
+ UpscalerData(
+ name="R-ESRGAN 4x+",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
+ scale=4,
+ upscaler=scaler,
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
),
- RealesrganModelInfo(
- name="Real-ESRGAN 4x plus anime 6B",
- location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
- netscale=4,
+ UpscalerData(
+ name="R-ESRGAN 4x+ Anime6B",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
+ scale=4,
+ upscaler=scaler,
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
),
- RealesrganModelInfo(
- name="Real-ESRGAN 2x plus",
- location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
- netscale=2,
+ UpscalerData(
+ name="R-ESRGAN 2x+",
+ path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
+ scale=2,
+ upscaler=scaler,
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
),
]
@@ -66,69 +137,3 @@ def get_realesrgan_models(): except Exception as e:
print("Error making Real-ESRGAN models list:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
-
-
-class UpscalerRealESRGAN(modules.images.Upscaler):
- def __init__(self, upscaling, model_index):
- self.upscaling = upscaling
- self.model_index = model_index
- self.name = realesrgan_models[model_index].name
-
- def do_upscale(self, img):
- return upscale_with_realesrgan(img, self.upscaling, self.model_index)
-
-
-def setup_model(dirname):
- global model_path
- if not os.path.exists(model_path):
- os.makedirs(model_path)
-
- global realesrgan_models
- global have_realesrgan
- if model_path != dirname:
- model_path = dirname
- try:
- from basicsr.archs.rrdbnet_arch import RRDBNet
- from realesrgan import RealESRGANer
- from realesrgan.archs.srvgg_arch import SRVGGNetCompact
-
- realesrgan_models = get_realesrgan_models()
- have_realesrgan = True
-
- for i, model in enumerate(realesrgan_models):
- if model.name in opts.realesrgan_enabled_models:
- modules.shared.sd_upscalers.append(UpscalerRealESRGAN(model.netscale, i))
-
- except Exception:
- print("Error importing Real-ESRGAN:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
-
- realesrgan_models = [RealesrganModelInfo('None', '', 0, None)]
- have_realesrgan = False
-
-
-def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
- if not have_realesrgan:
- return image
-
- info = realesrgan_models[RealESRGAN_model_index]
-
- model = info.model()
- model_file = load_file_from_url(url=info.location, model_dir=model_path, progress=True)
- if not os.path.exists(model_file):
- print("Unable to load RealESRGAN model: %s" % info.name)
- return image
-
- upsampler = RealESRGANer(
- scale=info.netscale,
- model_path=info.location,
- model=model,
- half=not cmd_opts.no_half,
- tile=opts.ESRGAN_tile,
- tile_pad=opts.ESRGAN_tile_overlap,
- )
-
- upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
-
- image = Image.fromarray(upsampled)
- return image
diff --git a/modules/sd_models.py b/modules/sd_models.py index 89b7d276..23826727 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -50,7 +50,7 @@ def setup_model(dirname): if not os.path.exists(model_path):
os.makedirs(model_path)
checkpoints_list.clear()
- model_list = modelloader.load_models(model_path, model_url, dirname, model_name, ext_filter=".ckpt")
+ model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=dirname, download_name=model_name, ext_filter=".ckpt")
cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
@@ -68,6 +68,7 @@ def setup_model(dirname): def model_hash(filename):
try:
+ print(f"Opening: {filename}")
with open(filename, "rb") as file:
import hashlib
m = hashlib.sha256()
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 666ee1ee..cfc3ee40 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -154,9 +154,9 @@ class VanillaStableDiffusionSampler: # existing code fails with cetin step counts, like 9
try:
- samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.ddim_eta)
+ samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_t=x, eta=p.ddim_eta)
except Exception:
- samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.ddim_eta)
+ samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_t=x, eta=p.ddim_eta)
return samples_ddim
diff --git a/modules/shared.py b/modules/shared.py index c27079eb..4c31039d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -1,18 +1,19 @@ -import sys
import argparse
+import datetime
import json
import os
+import sys
+
import gradio as gr
import tqdm
-import datetime
import modules.artists
-from modules.paths import script_path, sd_path
-from modules.devices import get_optimal_device
-import modules.styles
import modules.interrogate
import modules.memmon
import modules.sd_models
+import modules.styles
+from modules.devices import get_optimal_device
+from modules.paths import script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
@@ -38,6 +39,7 @@ parser.add_argument("--share", action='store_true', help="use share=True for gra parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(model_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(model_path, 'GFPGAN'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN'))
+parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN'))
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN'))
parser.add_argument("--stablediffusion-models-path", type=str, help="Path to directory with Stable-diffusion checkpoints.", default=os.path.join(model_path, 'SwinIR'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR'))
@@ -111,7 +113,7 @@ face_restorers = [] def realesrgan_models_names():
import modules.realesrgan_model
- return [x.name for x in modules.realesrgan_model.get_realesrgan_models()]
+ return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
class OptionInfo:
@@ -176,13 +178,11 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
- "realesrgan_enabled_models": OptionInfo(["Real-ESRGAN 4x plus", "Real-ESRGAN 4x plus anime 6B"], "Select which RealESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
+ "realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
"ldsr_pre_down": OptionInfo(1, "LDSR Pre-process downssample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
- "ldsr_post_down": OptionInfo(1, "LDSR Post-process down-sample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
-
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}),
}))
diff --git a/modules/swinir_model.py b/modules/swinir_model.py index f515779e..ea7b6301 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -1,92 +1,91 @@ import contextlib import os -import sys -import traceback import numpy as np import torch from PIL import Image from basicsr.utils.download_util import load_file_from_url -import modules.images from modules import modelloader from modules.paths import models_path from modules.shared import cmd_opts, opts, device from modules.swinir_model_arch import SwinIR as net +from modules.upscaler import Upscaler, UpscalerData -model_dir = "SwinIR" -model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" -model_name = "SwinIR x4" -model_path = os.path.join(models_path, model_dir) -cmd_path = "" precision_scope = ( torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext ) -def load_model(path, scale=4): - global model_path - global model_name - if "http" in path: - dl_name = "%s%s" % (model_name.replace(" ", "_"), ".pth") - filename = load_file_from_url(url=path, model_dir=model_path, file_name=dl_name, progress=True) - else: - filename = path - if filename is None or not os.path.exists(filename): - return None - model = net( - upscale=scale, - in_chans=3, - img_size=64, - window_size=8, - img_range=1.0, - depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], - embed_dim=240, - num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], - mlp_ratio=2, - upsampler="nearest+conv", - resi_connection="3conv", - ) - - pretrained_model = torch.load(filename) - model.load_state_dict(pretrained_model["params_ema"], strict=True) - if not cmd_opts.no_half: - model = model.half() - return model - - -def setup_model(dirname): - global model_path - global model_name - global cmd_path - if not os.path.exists(model_path): - os.makedirs(model_path) - cmd_path = dirname - model_file = "" - try: - models = modelloader.load_models(model_path, ext_filter=[".pt", ".pth"], command_path=cmd_path) - - if len(models) != 0: - model_file = models[0] - name = modelloader.friendly_name(model_file) - else: - # Add the "default" model if none are found. - model_file = model_url - name = model_name +class UpscalerSwinIR(Upscaler): + def __init__(self, dirname): + self.name = "SwinIR" + self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \ + "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \ + "-L_x4_GAN.pth " + self.model_name = "SwinIR 4x" + self.model_path = os.path.join(models_path, self.name) + self.user_path = dirname + super().__init__() + scalers = [] + model_files = self.find_models(ext_filter=[".pt", ".pth"]) + for model in model_files: + if "http" in model: + name = self.model_name + else: + name = modelloader.friendly_name(model) + model_data = UpscalerData(name, model, self) + scalers.append(model_data) + self.scalers = scalers + + def do_upscale(self, img, model_file): + model = self.load_model(model_file) + if model is None: + return img + model = model.to(device) + img = upscale(img, model) + try: + torch.cuda.empty_cache() + except: + pass + return img - modules.shared.sd_upscalers.append(UpscalerSwin(model_file, name)) - except Exception: - print(f"Error loading SwinIR model: {model_file}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + def load_model(self, path, scale=4): + if "http" in path: + dl_name = "%s%s" % (self.name.replace(" ", "_"), ".pth") + filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True) + else: + filename = path + if filename is None or not os.path.exists(filename): + return None + model = net( + upscale=scale, + in_chans=3, + img_size=64, + window_size=8, + img_range=1.0, + depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], + embed_dim=240, + num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], + mlp_ratio=2, + upsampler="nearest+conv", + resi_connection="3conv", + ) + + pretrained_model = torch.load(filename) + model.load_state_dict(pretrained_model["params_ema"], strict=True) + if not cmd_opts.no_half: + model = model.half() + return model def upscale( - img, - model, - tile=opts.SWIN_tile, - tile_overlap=opts.SWIN_tile_overlap, - window_size=8, - scale=4, + img, + model, + tile=opts.SWIN_tile, + tile_overlap=opts.SWIN_tile_overlap, + window_size=8, + scale=4, ): img = np.array(img) img = img[:, :, ::-1] @@ -125,34 +124,16 @@ def inference(img, model, tile, tile_overlap, window_size, scale): for h_idx in h_idx_list: for w_idx in w_idx_list: - in_patch = img[..., h_idx : h_idx + tile, w_idx : w_idx + tile] + in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile] out_patch = model(in_patch) out_patch_mask = torch.ones_like(out_patch) E[ - ..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf ].add_(out_patch) W[ - ..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf + ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf ].add_(out_patch_mask) output = E.div_(W) return output - - -class UpscalerSwin(modules.images.Upscaler): - def __init__(self, filename, title): - self.name = title - self.filename = filename - - def do_upscale(self, img): - model = load_model(self.filename) - if model is None: - return img - model = model.to(device) - img = upscale(img, model) - try: - torch.cuda.empty_cache() - except: - pass - return img
\ No newline at end of file diff --git a/modules/upscaler.py b/modules/upscaler.py new file mode 100644 index 00000000..d698282f --- /dev/null +++ b/modules/upscaler.py @@ -0,0 +1,121 @@ +import os +from abc import abstractmethod + +import PIL +import numpy as np +import torch +from PIL import Image + +import modules.shared +from modules import modelloader, shared + +LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) +from modules.paths import models_path + + +class Upscaler: + name = None + model_path = None + model_name = None + model_url = None + enable = True + filter = None + model = None + user_path = None + scalers: [] + tile = True + + def __init__(self, create_dirs=False): + self.mod_pad_h = None + self.tile_size = modules.shared.opts.ESRGAN_tile + self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap + self.device = modules.shared.device + self.img = None + self.output = None + self.scale = 1 + self.half = not modules.shared.cmd_opts.no_half + self.pre_pad = 0 + self.mod_scale = None + if self.name is not None and create_dirs: + self.model_path = os.path.join(models_path, self.name) + if not os.path.exists(self.model_path): + os.makedirs(self.model_path) + + try: + import cv2 + self.can_tile = True + except: + pass + + @abstractmethod + def do_upscale(self, img: PIL.Image, selected_model: str): + return img + + def upscale(self, img: PIL.Image, scale: int, selected_model: str = None): + self.scale = scale + dest_w = img.width * scale + dest_h = img.height * scale + for i in range(3): + if img.width >= dest_w and img.height >= dest_h: + break + img = self.do_upscale(img, selected_model) + if img.width != dest_w or img.height != dest_h: + img = img.resize(dest_w, dest_h, resample=LANCZOS) + + return img + + @abstractmethod + def load_model(self, path: str): + pass + + def find_models(self, ext_filter=None) -> list: + return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path) + + def update_status(self, prompt): + print(f"\nextras: {prompt}", file=shared.progress_print_out) + + +class UpscalerData: + name = None + data_path = None + scale: int = 4 + scaler: Upscaler = None + model: None + + def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None): + self.name = name + self.data_path = path + self.scaler = upscaler + self.scale = scale + self.model = model + + +class UpscalerNone(Upscaler): + name = "None" + scalers = [] + + def load_model(self, path): + pass + + def do_upscale(self, img, selected_model=None): + return img + + def __init__(self, dirname=None): + super().__init__(False) + self.scalers = [UpscalerData("None", None, self)] + + +class UpscalerLanczos(Upscaler): + scalers = [] + + def do_upscale(self, img, selected_model=None): + return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS) + + def load_model(self, _): + pass + + def __init__(self, dirname=None): + super().__init__(False) + self.name = "Lanczos" + self.scalers = [UpscalerData("Lanczos", None, self)] + @@ -1,9 +1,10 @@ import os
import signal
import threading
-
+import modules.paths
import modules.codeformer_model as codeformer
import modules.esrgan_model as esrgan
+import modules.bsrgan_model as bsrgan
import modules.extras
import modules.face_restoration
import modules.gfpgan_model as gfpgan
@@ -27,11 +28,7 @@ modules.sd_models.setup_model(cmd_opts.stablediffusion_models_path) codeformer.setup_model(cmd_opts.codeformer_models_path)
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
-
-esrgan.setup_model(cmd_opts.esrgan_models_path)
-swinir.setup_model(cmd_opts.swinir_models_path)
-realesrgan.setup_model(cmd_opts.realesrgan_models_path)
-ldsr.setup_model(cmd_opts.ldsr_models_path)
+modelloader.load_upscalers()
queue_lock = threading.Lock()
|