aboutsummaryrefslogtreecommitdiffstats
path: root/modules/modelloader.py
diff options
context:
space:
mode:
authorbrkirch <brkirch@users.noreply.github.com>2023-01-04 05:40:16 +0000
committerbrkirch <brkirch@users.noreply.github.com>2023-01-06 05:14:20 +0000
commitf6ab5a39d762a7791573d1c52ae5a3024b10e8ed (patch)
treec3958d77a6dae42457b571dbe0f1efec7ce45dd2 /modules/modelloader.py
parentd782a95967c9eea753df3333cd1954b6ec73eba0 (diff)
parent3e22e294135ed0327ce9d9738655ff03c53df3c0 (diff)
downloadstable-diffusion-webui-gfx803-f6ab5a39d762a7791573d1c52ae5a3024b10e8ed.tar.gz
stable-diffusion-webui-gfx803-f6ab5a39d762a7791573d1c52ae5a3024b10e8ed.tar.bz2
stable-diffusion-webui-gfx803-f6ab5a39d762a7791573d1c52ae5a3024b10e8ed.zip
Merge branch 'AUTOMATIC1111:master' into sub-quad_attn_opt
Diffstat (limited to 'modules/modelloader.py')
-rw-r--r--modules/modelloader.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/modules/modelloader.py b/modules/modelloader.py
index e647f6fa..6a1a7ac8 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -123,6 +123,23 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
pass
+builtin_upscaler_classes = []
+forbidden_upscaler_classes = set()
+
+
+def list_builtin_upscalers():
+ load_upscalers()
+
+ builtin_upscaler_classes.clear()
+ builtin_upscaler_classes.extend(Upscaler.__subclasses__())
+
+
+def forbid_loaded_nonbuiltin_upscalers():
+ for cls in Upscaler.__subclasses__():
+ if cls not in builtin_upscaler_classes:
+ forbidden_upscaler_classes.add(cls)
+
+
def load_upscalers():
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__
@@ -139,6 +156,9 @@ def load_upscalers():
datas = []
commandline_options = vars(shared.cmd_opts)
for cls in Upscaler.__subclasses__():
+ if cls in forbidden_upscaler_classes:
+ continue
+
name = cls.__name__
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
scaler = cls(commandline_options.get(cmd_name, None))