aboutsummaryrefslogtreecommitdiffstats
path: root/modules/esrgan_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/esrgan_model.py')
-rw-r--r--modules/esrgan_model.py56
1 files changed, 41 insertions, 15 deletions
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py
index 7f3baf31..dd0ee629 100644
--- a/modules/esrgan_model.py
+++ b/modules/esrgan_model.py
@@ -5,15 +5,35 @@ import traceback
import numpy as np
import torch
from PIL import Image
+from basicsr.utils.download_util import load_file_from_url
import modules.esrgam_model_arch as arch
+import modules.images
from modules import shared
-from modules.shared import opts
+from modules import shared, modelloader
from modules.devices import has_mps
-import modules.images
-
+from modules.paths import models_path
+from modules.shared import opts
-def load_model(filename):
+model_dir = "ESRGAN"
+model_path = os.path.join(models_path, model_dir)
+model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
+model_name = "ESRGAN_x4.pth"
+
+
+def load_model(path: str, name: str):
+ global model_path
+ global model_url
+ global model_dir
+ global model_name
+ if "http" in path:
+ filename = load_file_from_url(url=model_url, model_dir=model_path, file_name=model_name, progress=True)
+ else:
+ filename = path
+ if not os.path.exists(filename) or filename is None:
+ print("Unable to load %s from %s" % (model_dir, filename))
+ return None
+ print("Loading %s from %s" % (model_dir, filename))
# this code is adapted from https://github.com/xinntao/ESRGAN
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
@@ -118,24 +138,30 @@ def esrgan_upscale(model, img):
class UpscalerESRGAN(modules.images.Upscaler):
def __init__(self, filename, title):
self.name = title
- self.model = load_model(filename)
+ self.filename = filename
def do_upscale(self, img):
- model = self.model.to(shared.device)
+ model = load_model(self.filename, self.name)
+ if model is None:
+ return img
+ model.to(shared.device)
img = esrgan_upscale(model, img)
return img
-def load_models(dirname):
- for file in os.listdir(dirname):
- path = os.path.join(dirname, file)
- model_name, extension = os.path.splitext(file)
-
- if extension != '.pt' and extension != '.pth':
- continue
+def setup_model(dirname):
+ global model_path
+ global model_name
+ if not os.path.exists(model_path):
+ os.makedirs(model_path)
+ model_paths = modelloader.load_models(model_path, command_path=dirname, ext_filter=[".pt", ".pth"])
+ if len(model_paths) == 0:
+ modules.shared.sd_upscalers.append(UpscalerESRGAN(model_url, model_name))
+ for file in model_paths:
+ name = modelloader.friendly_name(file)
try:
- modules.shared.sd_upscalers.append(UpscalerESRGAN(path, model_name))
+ modules.shared.sd_upscalers.append(UpscalerESRGAN(file, name))
except Exception:
- print(f"Error loading ESRGAN model: {path}", file=sys.stderr)
+ print(f"Error loading ESRGAN model: {file}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)