| author | papuSpartan <30642826+papuSpartan@users.noreply.github.com> | 2022-10-21 18:53:32 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-21 18:53:32 +0000 |
| commit | 4a9ff0891abc413031b44926372f611513b4810f | |
| tree | f7622454a45669b13bce691e312f2c9dcfd9fb8a /modules/sd_models.py | |
| parent | a3b047b7c74dc6ca07f40aee778997fc1889d72f | |
| parent | f49c08ea566385db339c6628f65c3a121033f67c | |
Merge branch 'AUTOMATIC1111:master' into master
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r-- | modules/sd_models.py | 28 |
1 file changed, 22 insertions(+), 6 deletions(-)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index eae22e87..d99dbce8 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -9,6 +9,7 @@
 from ldm.util import instantiate_from_config
 
 from modules import shared, modelloader, devices
 from modules.paths import models_path
+from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
 
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(models_path, model_dir))
@@ -20,7 +21,7 @@ checkpoints_loaded = collections.OrderedDict()
 
 try:
     # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
-    from transformers import logging
+    from transformers import logging, CLIPModel
 
     logging.set_verbosity_error()
 except Exception:
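
This hunk broadens the existing `transformers` import to also pull in `CLIPModel` while keeping the warning-silencing workaround intact. As a standalone illustration of that workaround (a minimal sketch, not taken from this diff), `transformers` exposes a `logging` shim whose `set_verbosity_error()` hides the startup noise:

```python
# Minimal sketch of the silencing pattern in the hunk above.
# set_verbosity_error() is the transformers API for raising the log threshold
# so that "Some weights of the model checkpoint were not used when
# initializing..." warnings are no longer printed at startup.
try:
    from transformers import logging as transformers_logging

    transformers_logging.set_verbosity_error()
except Exception:
    # transformers may be absent in a stripped-down environment; the
    # silencing is purely cosmetic, so failure is safe to ignore.
    pass
```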
@@ -154,6 +155,9 @@ def get_state_dict_from_checkpoint(pl_sd):
     return pl_sd
+vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
+
+
 def load_model_weights(model, checkpoint_info):
     checkpoint_file = checkpoint_info.filename
     sd_model_hash = checkpoint_info.hash
@@ -185,7 +189,7 @@ def load_model_weights(model, checkpoint_info):
         if os.path.exists(vae_file):
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
- vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
+ vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
model.first_stage_model.load_state_dict(vae_dict)
model.first_stage_model.to(devices.dtype_vae)
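
The new `vae_ignore_keys` set exists because `load_state_dict` is strict by default: VAE checkpoints exported with EMA tracking carry bookkeeping tensors (`model_ema.decay`, `model_ema.num_updates`) that the webui's `first_stage_model` has no matching parameters for, so loading them would raise an unexpected-key error. A minimal standalone sketch of the filter, equivalent to the comprehension in the hunk (the `filter_vae_state_dict` helper name is illustrative, not from the repo):

```python
# Illustrative helper reproducing the filtering logic above: drop the
# loss-network weights and the EMA bookkeeping entries before handing the
# state dict to first_stage_model.load_state_dict().
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}

def filter_vae_state_dict(vae_ckpt: dict) -> dict:
    return {
        k: v
        for k, v in vae_ckpt["state_dict"].items()
        if k[0:4] != "loss" and k not in vae_ignore_keys
    }
```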
@@ -203,14 +207,26 @@ def load_model_weights(model, checkpoint_info):
     model.sd_checkpoint_info = checkpoint_info
-def load_model():
+def load_model(checkpoint_info=None):
     from modules import lowvram, sd_hijack
-    checkpoint_info = select_checkpoint()
+    checkpoint_info = checkpoint_info or select_checkpoint()
 
     if checkpoint_info.config != shared.cmd_opts.config:
         print(f"Loading config from: {checkpoint_info.config}")
 
     sd_config = OmegaConf.load(checkpoint_info.config)
+
+    if should_hijack_inpainting(checkpoint_info):
+        # Hardcoded config for now...
+        sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
+        sd_config.model.params.use_ema = False
+        sd_config.model.params.conditioning_key = "hybrid"
+        sd_config.model.params.unet_config.params.in_channels = 9
+
+        # Create a "fake" config with a different name so that we know to unload it when switching models.
+        checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
+
+        do_inpainting_hijack()
 
     sd_model = instantiate_from_config(sd_config.model)
     load_model_weights(sd_model, checkpoint_info)
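
The hunk above is the core of inpainting-model support. The RunwayML inpainting checkpoint uses a UNet with 9 input channels rather than 4: the noised 4-channel latent is concatenated with a 4-channel encoding of the masked image and a 1-channel downscaled mask, and `conditioning_key = "hybrid"` tells the LDM wrapper to combine that concat conditioning with the usual cross-attention text conditioning. Because the stock v1 config does not describe this architecture, the loaded config is patched in memory and given a fake `-inpainting.yaml` name so the model-switching code can tell the two variants apart. `should_hijack_inpainting` itself lives in `modules/sd_hijack_inpainting.py`, which is outside this diff; a plausible sketch, assuming detection keys off the checkpoint filename:

```python
# Plausible sketch of the detection helper (the real implementation lives in
# modules/sd_hijack_inpainting.py and is not part of this diff): hijack when
# the checkpoint is named as an inpainting model but its config has not yet
# been rewritten to the fake "-inpainting.yaml" variant.
def should_hijack_inpainting(checkpoint_info):
    return (
        str(checkpoint_info.filename).endswith("inpainting.ckpt")
        and not checkpoint_info.config.endswith("inpainting.yaml")
    )
```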
@@ -234,9 +250,9 @@ def reload_model_weights(sd_model, info=None):
     if sd_model.sd_model_checkpoint == checkpoint_info.filename:
         return
 
-    if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
+    if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
         checkpoints_loaded.clear()
-        shared.sd_model = load_model()
+        shared.sd_model = load_model(checkpoint_info)
         return shared.sd_model
 
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
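
With this change the reload decision has two triggers instead of one: a full `load_model()` is forced when the config path changes or when the requested checkpoint crosses the inpainting boundary in either direction; otherwise the function falls through and swaps the new weights into the live model. A condensed restatement of that predicate (the `needs_full_reload` name is hypothetical, not in the repo):

```python
# Hypothetical restatement of the condition in the hunk above: rebuild the
# model from scratch when the configs differ, or when exactly one of the old
# and new checkpoints is an inpainting model; otherwise reuse the live model.
def needs_full_reload(sd_model, checkpoint_info):
    return (
        sd_model.sd_checkpoint_info.config != checkpoint_info.config
        or should_hijack_inpainting(checkpoint_info)
        != should_hijack_inpainting(sd_model.sd_checkpoint_info)
    )
```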