diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2022-09-30 06:35:58 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-09-30 06:35:58 +0000 |
commit | 25414bcd05ef8072ce97056039bdd99379b74be9 (patch) | |
tree | 1fddc7e0921c0626e0b6310b915ab9ad7c65fdcd /modules/ldsr_model.py | |
parent | f80c3696f63a181f720105559d42ee53453ed0eb (diff) | |
parent | 435fd2112aee9a0e61408ac56663e41beea1e446 (diff) | |
download | stable-diffusion-webui-gfx803-25414bcd05ef8072ce97056039bdd99379b74be9.tar.gz stable-diffusion-webui-gfx803-25414bcd05ef8072ce97056039bdd99379b74be9.tar.bz2 stable-diffusion-webui-gfx803-25414bcd05ef8072ce97056039bdd99379b74be9.zip |
Merge pull request #1109 from d8ahazard/ModelLoader
Model Loader, Fixes
Diffstat (limited to 'modules/ldsr_model.py')
-rw-r--r-- | modules/ldsr_model.py | 92 |
1 files changed, 35 insertions, 57 deletions
diff --git a/modules/ldsr_model.py b/modules/ldsr_model.py index 95e84659..969d1a0d 100644 --- a/modules/ldsr_model.py +++ b/modules/ldsr_model.py @@ -1,67 +1,45 @@ import os import sys import traceback -from collections import namedtuple from basicsr.utils.download_util import load_file_from_url -import modules.images +from modules.upscaler import Upscaler, UpscalerData +from modules.ldsr_model_arch import LDSR from modules import shared -from modules.paths import script_path +from modules.paths import models_path -LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"]) -ldsr_models = [] -have_ldsr = False -LDSR_obj = None - - -class UpscalerLDSR(modules.images.Upscaler): - def __init__(self, steps): - self.steps = steps +class UpscalerLDSR(Upscaler): + def __init__(self, user_path): self.name = "LDSR" - - def do_upscale(self, img): - return upscale_with_ldsr(img) - - -def add_lsdr(): - modules.shared.sd_upscalers.append(UpscalerLDSR(100)) - - -def setup_ldsr(): - path = modules.paths.paths.get("LDSR", None) - if path is None: - return - global have_ldsr - global LDSR_obj - try: - from LDSR import LDSR - model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" - yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" - repo_path = 'latent-diffusion/experiments/pretrained_models/' - model_path = load_file_from_url(url=model_url, model_dir=os.path.join("repositories", repo_path), - progress=True, file_name="model.chkpt") - yaml_path = load_file_from_url(url=yaml_url, model_dir=os.path.join("repositories", repo_path), - progress=True, file_name="project.yaml") - have_ldsr = True - LDSR_obj = LDSR(model_path, yaml_path) - - - except Exception: - print("Error importing LDSR:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - have_ldsr = False - - -def upscale_with_ldsr(image): - setup_ldsr() - if not have_ldsr or LDSR_obj is None: - return image - - ddim_steps = shared.opts.ldsr_steps - pre_scale = shared.opts.ldsr_pre_down - post_scale = shared.opts.ldsr_post_down - - image = LDSR_obj.super_resolution(image, ddim_steps, pre_scale, post_scale) - return image + self.model_path = os.path.join(models_path, self.name) + self.user_path = user_path + self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" + self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" + super().__init__() + scaler_data = UpscalerData("LDSR", None, self) + self.scalers = [scaler_data] + + def load_model(self, path: str): + model = load_file_from_url(url=self.model_url, model_dir=self.model_path, + file_name="model.pth", progress=True) + yaml = load_file_from_url(url=self.model_url, model_dir=self.model_path, + file_name="project.yaml", progress=True) + + try: + return LDSR(model, yaml) + + except Exception: + print("Error importing LDSR:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + return None + + def do_upscale(self, img, path): + ldsr = self.load_model(path) + if ldsr is None: + print("NO LDSR!") + return img + ddim_steps = shared.opts.ldsr_steps + pre_scale = shared.opts.ldsr_pre_down + return ldsr.super_resolution(img, ddim_steps, self.scale) |