aboutsummaryrefslogtreecommitdiffstats
path: root/modules/swinir_model.py
diff options
context:
space:
mode:
authord8ahazard <d8ahazard@gmail.com>2022-09-26 14:29:50 +0000
committerd8ahazard <d8ahazard@gmail.com>2022-09-26 14:29:50 +0000
commit740070ea9cdb254209f66417418f2a4af8b099d6 (patch)
tree52896a6159b706024af9520c855c10091162372c /modules/swinir_model.py
parentbfb7f15d46048f27338eeac3a591a5943d03c5f1 (diff)
downloadstable-diffusion-webui-gfx803-740070ea9cdb254209f66417418f2a4af8b099d6.tar.gz
stable-diffusion-webui-gfx803-740070ea9cdb254209f66417418f2a4af8b099d6.tar.bz2
stable-diffusion-webui-gfx803-740070ea9cdb254209f66417418f2a4af8b099d6.zip
Re-implement universal model loading
Diffstat (limited to 'modules/swinir_model.py')
-rw-r--r--modules/swinir_model.py75
1 files changed, 55 insertions, 20 deletions
diff --git a/modules/swinir_model.py b/modules/swinir_model.py
index e86d0789..f515779e 100644
--- a/modules/swinir_model.py
+++ b/modules/swinir_model.py
@@ -1,21 +1,39 @@
+import contextlib
+import os
import sys
import traceback
-import cv2
-import os
-import contextlib
+
import numpy as np
-from PIL import Image
import torch
+from PIL import Image
+from basicsr.utils.download_util import load_file_from_url
+
import modules.images
+from modules import modelloader
+from modules.paths import models_path
from modules.shared import cmd_opts, opts, device
-from modules.swinir_arch import SwinIR as net
+from modules.swinir_model_arch import SwinIR as net
+model_dir = "SwinIR"
+model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
+model_name = "SwinIR x4"
+model_path = os.path.join(models_path, model_dir)
+cmd_path = ""
precision_scope = (
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
)
-def load_model(filename, scale=4):
+def load_model(path, scale=4):
+ global model_path
+ global model_name
+ if "http" in path:
+ dl_name = "%s%s" % (model_name.replace(" ", "_"), ".pth")
+ filename = load_file_from_url(url=path, model_dir=model_path, file_name=dl_name, progress=True)
+ else:
+ filename = path
+ if filename is None or not os.path.exists(filename):
+ return None
model = net(
upscale=scale,
in_chans=3,
@@ -37,19 +55,29 @@ def load_model(filename, scale=4):
return model
-def load_models(dirname):
- for file in os.listdir(dirname):
- path = os.path.join(dirname, file)
- model_name, extension = os.path.splitext(file)
+def setup_model(dirname):
+ global model_path
+ global model_name
+ global cmd_path
+ if not os.path.exists(model_path):
+ os.makedirs(model_path)
+ cmd_path = dirname
+ model_file = ""
+ try:
+ models = modelloader.load_models(model_path, ext_filter=[".pt", ".pth"], command_path=cmd_path)
- if extension != ".pt" and extension != ".pth":
- continue
+ if len(models) != 0:
+ model_file = models[0]
+ name = modelloader.friendly_name(model_file)
+ else:
+ # Add the "default" model if none are found.
+ model_file = model_url
+ name = model_name
- try:
- modules.shared.sd_upscalers.append(UpscalerSwin(path, model_name))
- except Exception:
- print(f"Error loading SwinIR model: {path}", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
+ modules.shared.sd_upscalers.append(UpscalerSwin(model_file, name))
+ except Exception:
+ print(f"Error loading SwinIR model: {model_file}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
def upscale(
@@ -115,9 +143,16 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
class UpscalerSwin(modules.images.Upscaler):
def __init__(self, filename, title):
self.name = title
- self.model = load_model(filename)
+ self.filename = filename
def do_upscale(self, img):
- model = self.model.to(device)
+ model = load_model(self.filename)
+ if model is None:
+ return img
+ model = model.to(device)
img = upscale(img, model)
- return img
+ try:
+ torch.cuda.empty_cache()
+ except:
+ pass
+ return img \ No newline at end of file