aboutsummaryrefslogtreecommitdiffstats
path: root/modules/hat_model.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2024-03-02 04:03:13 +0000
committerAUTOMATIC1111 <16777216c@gmail.com>2024-03-02 04:03:13 +0000
commitbef51aed032c0aaa5cfd80445bc4cf0d85b408b5 (patch)
tree42957c454a4ac8d98488f19811b60359d05d88ba /modules/hat_model.py
parentcf2772fab0af5573da775e7437e6acdca424f26e (diff)
parent13984857890401e8605a3e53bd671e900a18d73f (diff)
downloadstable-diffusion-webui-gfx803-bef51aed032c0aaa5cfd80445bc4cf0d85b408b5.tar.gz
stable-diffusion-webui-gfx803-bef51aed032c0aaa5cfd80445bc4cf0d85b408b5.tar.bz2
stable-diffusion-webui-gfx803-bef51aed032c0aaa5cfd80445bc4cf0d85b408b5.zip
Merge branch 'release_candidate'
Diffstat (limited to 'modules/hat_model.py')
-rw-r--r--modules/hat_model.py43
1 files changed, 43 insertions, 0 deletions
diff --git a/modules/hat_model.py b/modules/hat_model.py
new file mode 100644
index 00000000..7f2abb41
--- /dev/null
+++ b/modules/hat_model.py
@@ -0,0 +1,43 @@
+import os
+import sys
+
+from modules import modelloader, devices
+from modules.shared import opts
+from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model
+
+
+class UpscalerHAT(Upscaler):
+ def __init__(self, dirname):
+ self.name = "HAT"
+ self.scalers = []
+ self.user_path = dirname
+ super().__init__()
+ for file in self.find_models(ext_filter=[".pt", ".pth"]):
+ name = modelloader.friendly_name(file)
+ scale = 4 # TODO: scale might not be 4, but we can't know without loading the model
+ scaler_data = UpscalerData(name, file, upscaler=self, scale=scale)
+ self.scalers.append(scaler_data)
+
+ def do_upscale(self, img, selected_model):
+ try:
+ model = self.load_model(selected_model)
+ except Exception as e:
+ print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr)
+ return img
+ model.to(devices.device_esrgan) # TODO: should probably be device_hat
+ return upscale_with_model(
+ model,
+ img,
+ tile_size=opts.ESRGAN_tile, # TODO: should probably be HAT_tile
+ tile_overlap=opts.ESRGAN_tile_overlap, # TODO: should probably be HAT_tile_overlap
+ )
+
+ def load_model(self, path: str):
+ if not os.path.isfile(path):
+ raise FileNotFoundError(f"Model file {path} not found")
+ return modelloader.load_spandrel_model(
+ path,
+ device=devices.device_esrgan, # TODO: should probably be device_hat
+ expected_architecture='HAT',
+ )