aboutsummaryrefslogtreecommitdiffstats
path: root/modules/swinir_model.py
diff options
context:
space:
mode:
authord8ahazard <d8ahazard@gmail.com>2022-09-29 22:46:23 +0000
committerd8ahazard <d8ahazard@gmail.com>2022-09-29 22:46:23 +0000
commit0dce0df1ee63b2f158805c1a1f1a3743cc4a104b (patch)
treedfcec33656d06835e71961b117b63e510cb9bff2 /modules/swinir_model.py
parent31ad536c331df14dd785bfd2a1f93f91a8f7839e (diff)
downloadstable-diffusion-webui-gfx803-0dce0df1ee63b2f158805c1a1f1a3743cc4a104b.tar.gz
stable-diffusion-webui-gfx803-0dce0df1ee63b2f158805c1a1f1a3743cc4a104b.tar.bz2
stable-diffusion-webui-gfx803-0dce0df1ee63b2f158805c1a1f1a3743cc4a104b.zip
Holy $hit.
Yep. Fix gfpgan_model_arch requirement(s). Add Upscaler base class, move from images. Add a lot of methods to Upscaler. Re-work all the child upscalers to be proper classes. Add BSRGAN scaler. Add ldsr_model_arch class, removing the dependency for another repo that just uses regular latent-diffusion stuff. Add one universal method that will always find and load new upscaler models without having to add new "setup_model" calls. Still need to add command line params, but that could probably be automated. Add a "self.scale" property to all Upscalers so the scalers themselves can do "things" in response to the requested upscaling size. Ensure LDSR doesn't get stuck in a longer loop of "upscale/downscale/upscale" as we try to reach the target upscale size. Add typehints for IDE sanity. PEP-8 improvements. Moar.
Diffstat (limited to 'modules/swinir_model.py')
-rw-r--r--modules/swinir_model.py157
1 files changed, 69 insertions, 88 deletions
diff --git a/modules/swinir_model.py b/modules/swinir_model.py
index f515779e..ea7b6301 100644
--- a/modules/swinir_model.py
+++ b/modules/swinir_model.py
@@ -1,92 +1,91 @@
import contextlib
import os
-import sys
-import traceback
import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
-import modules.images
from modules import modelloader
from modules.paths import models_path
from modules.shared import cmd_opts, opts, device
from modules.swinir_model_arch import SwinIR as net
+from modules.upscaler import Upscaler, UpscalerData
-model_dir = "SwinIR"
-model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
-model_name = "SwinIR x4"
-model_path = os.path.join(models_path, model_dir)
-cmd_path = ""
precision_scope = (
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
)
-def load_model(path, scale=4):
- global model_path
- global model_name
- if "http" in path:
- dl_name = "%s%s" % (model_name.replace(" ", "_"), ".pth")
- filename = load_file_from_url(url=path, model_dir=model_path, file_name=dl_name, progress=True)
- else:
- filename = path
- if filename is None or not os.path.exists(filename):
- return None
- model = net(
- upscale=scale,
- in_chans=3,
- img_size=64,
- window_size=8,
- img_range=1.0,
- depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
- embed_dim=240,
- num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
- mlp_ratio=2,
- upsampler="nearest+conv",
- resi_connection="3conv",
- )
-
- pretrained_model = torch.load(filename)
- model.load_state_dict(pretrained_model["params_ema"], strict=True)
- if not cmd_opts.no_half:
- model = model.half()
- return model
-
-
-def setup_model(dirname):
- global model_path
- global model_name
- global cmd_path
- if not os.path.exists(model_path):
- os.makedirs(model_path)
- cmd_path = dirname
- model_file = ""
- try:
- models = modelloader.load_models(model_path, ext_filter=[".pt", ".pth"], command_path=cmd_path)
-
- if len(models) != 0:
- model_file = models[0]
- name = modelloader.friendly_name(model_file)
- else:
- # Add the "default" model if none are found.
- model_file = model_url
- name = model_name
+class UpscalerSwinIR(Upscaler):
+ def __init__(self, dirname):
+ self.name = "SwinIR"
+ self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
+ "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
+ "-L_x4_GAN.pth "
+ self.model_name = "SwinIR 4x"
+ self.model_path = os.path.join(models_path, self.name)
+ self.user_path = dirname
+ super().__init__()
+ scalers = []
+ model_files = self.find_models(ext_filter=[".pt", ".pth"])
+ for model in model_files:
+ if "http" in model:
+ name = self.model_name
+ else:
+ name = modelloader.friendly_name(model)
+ model_data = UpscalerData(name, model, self)
+ scalers.append(model_data)
+ self.scalers = scalers
+
+ def do_upscale(self, img, model_file):
+ model = self.load_model(model_file)
+ if model is None:
+ return img
+ model = model.to(device)
+ img = upscale(img, model)
+ try:
+ torch.cuda.empty_cache()
+ except:
+ pass
+ return img
- modules.shared.sd_upscalers.append(UpscalerSwin(model_file, name))
- except Exception:
- print(f"Error loading SwinIR model: {model_file}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ def load_model(self, path, scale=4):
+ if "http" in path:
+ dl_name = "%s%s" % (self.name.replace(" ", "_"), ".pth")
+ filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
+ else:
+ filename = path
+ if filename is None or not os.path.exists(filename):
+ return None
+ model = net(
+ upscale=scale,
+ in_chans=3,
+ img_size=64,
+ window_size=8,
+ img_range=1.0,
+ depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
+ embed_dim=240,
+ num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
+ mlp_ratio=2,
+ upsampler="nearest+conv",
+ resi_connection="3conv",
+ )
+
+ pretrained_model = torch.load(filename)
+ model.load_state_dict(pretrained_model["params_ema"], strict=True)
+ if not cmd_opts.no_half:
+ model = model.half()
+ return model
def upscale(
- img,
- model,
- tile=opts.SWIN_tile,
- tile_overlap=opts.SWIN_tile_overlap,
- window_size=8,
- scale=4,
+ img,
+ model,
+ tile=opts.SWIN_tile,
+ tile_overlap=opts.SWIN_tile_overlap,
+ window_size=8,
+ scale=4,
):
img = np.array(img)
img = img[:, :, ::-1]
@@ -125,34 +124,16 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
for h_idx in h_idx_list:
for w_idx in w_idx_list:
- in_patch = img[..., h_idx : h_idx + tile, w_idx : w_idx + tile]
+ in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
E[
- ..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf
+ ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch)
W[
- ..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf
+ ..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch_mask)
output = E.div_(W)
return output
-
-
-class UpscalerSwin(modules.images.Upscaler):
- def __init__(self, filename, title):
- self.name = title
- self.filename = filename
-
- def do_upscale(self, img):
- model = load_model(self.filename)
- if model is None:
- return img
- model = model.to(device)
- img = upscale(img, model)
- try:
- torch.cuda.empty_cache()
- except:
- pass
- return img \ No newline at end of file