aboutsummaryrefslogtreecommitdiffstats
path: root/modules/ldsr_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/ldsr_model.py')
-rw-r--r--modules/ldsr_model.py103
1 files changed, 37 insertions, 66 deletions
diff --git a/modules/ldsr_model.py b/modules/ldsr_model.py
index 4f9b1657..969d1a0d 100644
--- a/modules/ldsr_model.py
+++ b/modules/ldsr_model.py
@@ -1,74 +1,45 @@
import os
import sys
import traceback
-from collections import namedtuple
-from modules import shared, images, modelloader, paths
-from modules.paths import models_path
-
-model_dir = "LDSR"
-model_path = os.path.join(models_path, model_dir)
-cmd_path = None
-model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
-yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
+from basicsr.utils.download_util import load_file_from_url
-LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"])
-
-ldsr_models = []
-have_ldsr = False
-LDSR_obj = None
+from modules.upscaler import Upscaler, UpscalerData
+from modules.ldsr_model_arch import LDSR
+from modules import shared
+from modules.paths import models_path
-class UpscalerLDSR(images.Upscaler):
- def __init__(self, steps):
- self.steps = steps
+class UpscalerLDSR(Upscaler):
+ def __init__(self, user_path):
self.name = "LDSR"
-
- def do_upscale(self, img):
- return upscale_with_ldsr(img)
-
-
-def setup_model(dirname):
- global cmd_path
- global model_path
- if not os.path.exists(model_path):
- os.makedirs(model_path)
- cmd_path = dirname
- shared.sd_upscalers.append(UpscalerLDSR(100))
-
-
-def prepare_ldsr():
- path = paths.paths.get("LDSR", None)
- if path is None:
- return
- global have_ldsr
- global LDSR_obj
- try:
- from LDSR import LDSR
- model_files = modelloader.load_models(model_path, model_url, cmd_path, dl_name="model.ckpt", ext_filter=[".ckpt"])
- yaml_files = modelloader.load_models(model_path, yaml_url, cmd_path, dl_name="project.yaml", ext_filter=[".yaml"])
- if len(model_files) != 0 and len(yaml_files) != 0:
- model_file = model_files[0]
- yaml_file = yaml_files[0]
- have_ldsr = True
- LDSR_obj = LDSR(model_file, yaml_file)
- else:
- return
-
- except Exception:
- print("Error importing LDSR:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- have_ldsr = False
-
-
-def upscale_with_ldsr(image):
- prepare_ldsr()
- if not have_ldsr or LDSR_obj is None:
- return image
-
- ddim_steps = shared.opts.ldsr_steps
- pre_scale = shared.opts.ldsr_pre_down
- post_scale = shared.opts.ldsr_post_down
-
- image = LDSR_obj.super_resolution(image, ddim_steps, pre_scale, post_scale)
- return image
+ self.model_path = os.path.join(models_path, self.name)
+ self.user_path = user_path
+ self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
+ self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
+ super().__init__()
+ scaler_data = UpscalerData("LDSR", None, self)
+ self.scalers = [scaler_data]
+
+ def load_model(self, path: str):
+ model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
+ file_name="model.pth", progress=True)
+ yaml = load_file_from_url(url=self.model_url, model_dir=self.model_path,
+ file_name="project.yaml", progress=True)
+
+ try:
+ return LDSR(model, yaml)
+
+ except Exception:
+ print("Error importing LDSR:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ return None
+
+ def do_upscale(self, img, path):
+ ldsr = self.load_model(path)
+ if ldsr is None:
+ print("NO LDSR!")
+ return img
+ ddim_steps = shared.opts.ldsr_steps
+ pre_scale = shared.opts.ldsr_pre_down
+ return ldsr.super_resolution(img, ddim_steps, self.scale)