diff options
author | d8ahazard <d8ahazard@gmail.com> | 2022-09-26 14:29:50 +0000 |
---|---|---|
committer | d8ahazard <d8ahazard@gmail.com> | 2022-09-26 14:29:50 +0000 |
commit | 740070ea9cdb254209f66417418f2a4af8b099d6 (patch) | |
tree | 52896a6159b706024af9520c855c10091162372c /modules/realesrgan_model.py | |
parent | bfb7f15d46048f27338eeac3a591a5943d03c5f1 (diff) | |
download | stable-diffusion-webui-gfx803-740070ea9cdb254209f66417418f2a4af8b099d6.tar.gz stable-diffusion-webui-gfx803-740070ea9cdb254209f66417418f2a4af8b099d6.tar.bz2 stable-diffusion-webui-gfx803-740070ea9cdb254209f66417418f2a4af8b099d6.zip |
Re-implement universal model loading
Diffstat (limited to 'modules/realesrgan_model.py')
-rw-r--r-- | modules/realesrgan_model.py | 23 |
1 files changed, 19 insertions, 4 deletions
diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index c32d6c4c..458bf678 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,14 +1,20 @@ +import os
import sys
import traceback
from collections import namedtuple
import numpy as np
from PIL import Image
+from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
import modules.images
+from modules.paths import models_path
from modules.shared import cmd_opts, opts
+model_dir = "RealESRGAN"
+model_path = os.path.join(models_path, model_dir)
+cmd_dir = None
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
realesrgan_models = []
have_realesrgan = False
@@ -17,7 +23,6 @@ have_realesrgan = False def get_realesrgan_models():
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
- from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
models = [
RealesrganModelInfo(
@@ -59,7 +64,7 @@ def get_realesrgan_models(): ]
return models
except Exception as e:
- print("Error makeing Real-ESRGAN midels list:", file=sys.stderr)
+ print("Error making Real-ESRGAN models list:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
@@ -73,10 +78,15 @@ class UpscalerRealESRGAN(modules.images.Upscaler): return upscale_with_realesrgan(img, self.upscaling, self.model_index)
-def setup_realesrgan():
+def setup_model(dirname):
+ global model_path
+ if not os.path.exists(model_path):
+ os.makedirs(model_path)
+
global realesrgan_models
global have_realesrgan
-
+ if model_path != dirname:
+ model_path = dirname
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
@@ -104,6 +114,11 @@ def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index) info = realesrgan_models[RealESRGAN_model_index]
model = info.model()
+ model_file = load_file_from_url(url=info.location, model_dir=model_path, progress=True)
+ if not os.path.exists(model_file):
+ print("Unable to load RealESRGAN model: %s" % info.name)
+ return image
+
upsampler = RealESRGANer(
scale=info.netscale,
model_path=info.location,
|