diff options
author | d8ahazard <d8ahazard@gmail.com> | 2022-09-26 14:29:50 +0000 |
---|---|---|
committer | d8ahazard <d8ahazard@gmail.com> | 2022-09-26 14:29:50 +0000 |
commit | 740070ea9cdb254209f66417418f2a4af8b099d6 (patch) | |
tree | 52896a6159b706024af9520c855c10091162372c /modules/swinir_model.py | |
parent | bfb7f15d46048f27338eeac3a591a5943d03c5f1 (diff) | |
download | stable-diffusion-webui-gfx803-740070ea9cdb254209f66417418f2a4af8b099d6.tar.gz stable-diffusion-webui-gfx803-740070ea9cdb254209f66417418f2a4af8b099d6.tar.bz2 stable-diffusion-webui-gfx803-740070ea9cdb254209f66417418f2a4af8b099d6.zip |
Re-implement universal model loading
Diffstat (limited to 'modules/swinir_model.py')
-rw-r--r-- | modules/swinir_model.py | 75 |
1 files changed, 55 insertions, 20 deletions
diff --git a/modules/swinir_model.py b/modules/swinir_model.py index e86d0789..f515779e 100644 --- a/modules/swinir_model.py +++ b/modules/swinir_model.py @@ -1,21 +1,39 @@ +import contextlib +import os import sys import traceback -import cv2 -import os -import contextlib + import numpy as np -from PIL import Image import torch +from PIL import Image +from basicsr.utils.download_util import load_file_from_url + import modules.images +from modules import modelloader +from modules.paths import models_path from modules.shared import cmd_opts, opts, device -from modules.swinir_arch import SwinIR as net +from modules.swinir_model_arch import SwinIR as net +model_dir = "SwinIR" +model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth" +model_name = "SwinIR x4" +model_path = os.path.join(models_path, model_dir) +cmd_path = "" precision_scope = ( torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext ) -def load_model(filename, scale=4): +def load_model(path, scale=4): + global model_path + global model_name + if "http" in path: + dl_name = "%s%s" % (model_name.replace(" ", "_"), ".pth") + filename = load_file_from_url(url=path, model_dir=model_path, file_name=dl_name, progress=True) + else: + filename = path + if filename is None or not os.path.exists(filename): + return None model = net( upscale=scale, in_chans=3, @@ -37,19 +55,29 @@ def load_model(filename, scale=4): return model -def load_models(dirname): - for file in os.listdir(dirname): - path = os.path.join(dirname, file) - model_name, extension = os.path.splitext(file) +def setup_model(dirname): + global model_path + global model_name + global cmd_path + if not os.path.exists(model_path): + os.makedirs(model_path) + cmd_path = dirname + model_file = "" + try: + models = modelloader.load_models(model_path, ext_filter=[".pt", ".pth"], command_path=cmd_path) - if extension != ".pt" and extension != ".pth": - continue + if len(models) != 0: + model_file = models[0] + name = modelloader.friendly_name(model_file) + else: + # Add the "default" model if none are found. + model_file = model_url + name = model_name - try: - modules.shared.sd_upscalers.append(UpscalerSwin(path, model_name)) - except Exception: - print(f"Error loading SwinIR model: {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + modules.shared.sd_upscalers.append(UpscalerSwin(model_file, name)) + except Exception: + print(f"Error loading SwinIR model: {model_file}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) def upscale( @@ -115,9 +143,16 @@ def inference(img, model, tile, tile_overlap, window_size, scale): class UpscalerSwin(modules.images.Upscaler): def __init__(self, filename, title): self.name = title - self.model = load_model(filename) + self.filename = filename def do_upscale(self, img): - model = self.model.to(device) + model = load_model(self.filename) + if model is None: + return img + model = model.to(device) img = upscale(img, model) - return img + try: + torch.cuda.empty_cache() + except: + pass + return img
\ No newline at end of file |