aboutsummaryrefslogtreecommitdiffstats
path: root/webui.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-09-17 09:05:04 +0000
committerAUTOMATIC <16777216c@gmail.com>2022-09-17 09:05:18 +0000
commit247f58a5e740a7bd3980815961425b778d77ec28 (patch)
tree56536dd5c7a078720e4d0cc71530280c70df3c47 /webui.py
parentb8be33dad13d4937c6ef8fbb49715d843c3dd586 (diff)
downloadstable-diffusion-webui-gfx803-247f58a5e740a7bd3980815961425b778d77ec28.tar.gz
stable-diffusion-webui-gfx803-247f58a5e740a7bd3980815961425b778d77ec28.tar.bz2
stable-diffusion-webui-gfx803-247f58a5e740a7bd3980815961425b778d77ec28.zip
add support for switching model checkpoints at runtime
Diffstat (limited to 'webui.py')
-rw-r--r--webui.py61
1 file changed, 10 insertions, 51 deletions
diff --git a/webui.py b/webui.py
index add72123..ff8997db 100644
--- a/webui.py
+++ b/webui.py
@@ -3,13 +3,8 @@ import threading
from modules.paths import script_path
-import torch
-from omegaconf import OmegaConf
-
import signal
-from ldm.util import instantiate_from_config
-
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.ui
@@ -24,6 +19,7 @@ import modules.extras
import modules.lowvram
import modules.txt2img
import modules.img2img
+import modules.sd_models
modules.codeformer_model.setup_codeformer()
@@ -33,29 +29,17 @@ shared.face_restorers.append(modules.face_restoration.FaceRestoration())
esrgan.load_models(cmd_opts.esrgan_models_path)
realesrgan.setup_realesrgan()
+queue_lock = threading.Lock()
-def load_model_from_config(config, ckpt, verbose=False):
- print(f"Loading model [{shared.sd_model_hash}] from {ckpt}")
- pl_sd = torch.load(ckpt, map_location="cpu")
- if "global_step" in pl_sd:
- print(f"Global Step: {pl_sd['global_step']}")
- sd = pl_sd["state_dict"]
- model = instantiate_from_config(config.model)
- m, u = model.load_state_dict(sd, strict=False)
- if len(m) > 0 and verbose:
- print("missing keys:")
- print(m)
- if len(u) > 0 and verbose:
- print("unexpected keys:")
- print(u)
- if cmd_opts.opt_channelslast:
- model = model.to(memory_format=torch.channels_last)
- model.eval()
- return model
+def wrap_queued_call(func):
+ def f(*args, **kwargs):
+ with queue_lock:
+ res = func(*args, **kwargs)
+ return res
-queue_lock = threading.Lock()
+ return f
def wrap_gradio_gpu_call(func):
@@ -80,33 +64,8 @@ def wrap_gradio_gpu_call(func):
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
-try:
- # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
-
- from transformers import logging
-
- logging.set_verbosity_error()
-except Exception:
- pass
-
-with open(cmd_opts.ckpt, "rb") as file:
- import hashlib
- m = hashlib.sha256()
-
- file.seek(0x100000)
- m.update(file.read(0x10000))
- shared.sd_model_hash = m.hexdigest()[0:8]
-
-sd_config = OmegaConf.load(cmd_opts.config)
-shared.sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
-shared.sd_model = (shared.sd_model if cmd_opts.no_half else shared.sd_model.half())
-
-if cmd_opts.lowvram or cmd_opts.medvram:
- modules.lowvram.setup_for_low_vram(shared.sd_model, cmd_opts.medvram)
-else:
- shared.sd_model = shared.sd_model.to(shared.device)
-
-modules.sd_hijack.model_hijack.hijack(shared.sd_model)
+shared.sd_model = modules.sd_models.load_model()
+shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
def webui():