diff options
author | Leon Feng <523684+leon0707@users.noreply.github.com> | 2023-07-18 08:24:14 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-18 08:24:14 +0000 |
commit | a3730bd9becd2f1f5d209885b694b0dec178d110 (patch) | |
tree | 8ac9948d89606f7519df786f07f6ddb93c3d2720 /modules/esrgan_model.py | |
parent | d6668347c8b85b11b696ac56777cc396e34ee1f9 (diff) | |
parent | 871b8687a82bb2ca907d8a49c87aed7635b8fc33 (diff) | |
download | stable-diffusion-webui-gfx803-a3730bd9becd2f1f5d209885b694b0dec178d110.tar.gz stable-diffusion-webui-gfx803-a3730bd9becd2f1f5d209885b694b0dec178d110.tar.bz2 stable-diffusion-webui-gfx803-a3730bd9becd2f1f5d209885b694b0dec178d110.zip |
Merge branch 'dev' into fix-11805
Diffstat (limited to 'modules/esrgan_model.py')
-rw-r--r-- | modules/esrgan_model.py | 23 |
1 files changed, 10 insertions, 13 deletions
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 2fced999..02a1727d 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,15 +1,13 @@ -import os
+import sys
import numpy as np
import torch
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices
-from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
-
+from modules.upscaler import Upscaler, UpscalerData
def mod2normal(state_dict):
@@ -134,7 +132,7 @@ class UpscalerESRGAN(Upscaler): scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data)
for file in model_paths:
- if "http" in file:
+ if file.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(file)
@@ -143,26 +141,25 @@ class UpscalerESRGAN(Upscaler): self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model):
- model = self.load_model(selected_model)
- if model is None:
+ try:
+ model = self.load_model(selected_model)
+ except Exception as e:
+ print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
return img
model.to(devices.device_esrgan)
img = esrgan_upscale(model, img)
return img
def load_model(self, path: str):
- if "http" in path:
- filename = load_file_from_url(
+ if path.startswith("http"):
+ # TODO: this doesn't use `path` at all?
+ filename = modelloader.load_file_from_url(
url=self.model_url,
model_dir=self.model_download_path,
file_name=f"{self.model_name}.pth",
- progress=True,
)
else:
filename = path
- if not os.path.exists(filename) or filename is None:
- print(f"Unable to load {self.model_path} from {filename}")
- return None
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
|