aboutsummaryrefslogtreecommitdiffstats
path: root/modules/esrgan_model.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-10-11 08:14:36 +0000
committerAUTOMATIC <16777216c@gmail.com>2022-10-11 08:14:36 +0000
commit5de806184f6687e46cf936b92055146dc6cf2994 (patch)
treed84c2daa8798c3d2f8e99e17234a40065491182d /modules/esrgan_model.py
parent12c4d5c6b5bf9dd50d0601c36af4f99b65316d58 (diff)
parent948533950c9db5069a874d925fadd50bac00fdb5 (diff)
downloadstable-diffusion-webui-gfx803-5de806184f6687e46cf936b92055146dc6cf2994.tar.gz
stable-diffusion-webui-gfx803-5de806184f6687e46cf936b92055146dc6cf2994.tar.bz2
stable-diffusion-webui-gfx803-5de806184f6687e46cf936b92055146dc6cf2994.zip
Merge branch 'master' into hypernetwork-training
Diffstat (limited to 'modules/esrgan_model.py')
-rw-r--r--modules/esrgan_model.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index d17e730f..46ad0da3 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -5,9 +5,8 @@ import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
-import modules.esrgam_model_arch as arch
+import modules.esrgan_model_arch as arch
from modules import shared, modelloader, images, devices
-from modules.paths import models_path
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
@@ -76,7 +75,6 @@ class UpscalerESRGAN(Upscaler):
self.model_name = "ESRGAN_4x"
self.scalers = []
self.user_path = dirname
- self.model_path = os.path.join(models_path, self.name)
super().__init__()
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
scalers = []
@@ -111,7 +109,7 @@ class UpscalerESRGAN(Upscaler):
print("Unable to load %s from %s" % (self.model_path, filename))
return None
- pretrained_net = torch.load(filename, map_location='cpu' if shared.device.type == 'mps' else None)
+ pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
pretrained_net = fix_model_layers(crt_model, pretrained_net)