aboutsummaryrefslogtreecommitdiffstats
path: root/modules/esrgan_model.py
diff options
context:
space:
mode:
authorAarni Koskela <akx@iki.fi>2023-05-29 07:38:51 +0000
committerAarni Koskela <akx@iki.fi>2023-06-13 09:44:25 +0000
commitbf67a5dcf44c3dbd88d1913478d4e02477915f33 (patch)
treee4c1705c31bb76eef77781ef6df0628e2716f929 /modules/esrgan_model.py
parente3a973a68df3cfe13039dae33d19cf2c02a741e0 (diff)
downloadstable-diffusion-webui-gfx803-bf67a5dcf44c3dbd88d1913478d4e02477915f33.tar.gz
stable-diffusion-webui-gfx803-bf67a5dcf44c3dbd88d1913478d4e02477915f33.tar.bz2
stable-diffusion-webui-gfx803-bf67a5dcf44c3dbd88d1913478d4e02477915f33.zip
Upscaler.load_model: don't return None, just use exceptions
Diffstat (limited to 'modules/esrgan_model.py')
-rw-r--r--modules/esrgan_model.py14
1 files changed, 6 insertions, 8 deletions
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index a20e8d91..02a1727d 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -1,4 +1,4 @@
-import os
+import sys
import numpy as np
import torch
@@ -6,9 +6,8 @@ from PIL import Image
import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices
-from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
-
+from modules.upscaler import Upscaler, UpscalerData
def mod2normal(state_dict):
@@ -142,8 +141,10 @@ class UpscalerESRGAN(Upscaler):
self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model):
- model = self.load_model(selected_model)
- if model is None:
+ try:
+ model = self.load_model(selected_model)
+ except Exception as e:
+ print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
return img
model.to(devices.device_esrgan)
img = esrgan_upscale(model, img)
@@ -159,9 +160,6 @@ class UpscalerESRGAN(Upscaler):
)
else:
filename = path
- if not os.path.exists(filename) or filename is None:
- print(f"Unable to load {self.model_path} from {filename}")
- return None
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)