aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_models.py
diff options
context:
space:
mode:
authorMuhammad Rizqi Nur <rizqinur2010@gmail.com>2022-11-01 17:25:08 +0000
committerGitHub <noreply@github.com>2022-11-01 17:25:08 +0000
commitf8c6468d42e1202f7aeaeb961ab003aa0a2daf99 (patch)
treea2542ce9bd8bba1e8aa93acd510a12ca8a0b344f /modules/sd_models.py
parent7c8c3715f552378cf81ad28f26fad92b37bd153d (diff)
parent198a1ffcfc963a3d74674fad560e87dbebf7949f (diff)
downloadstable-diffusion-webui-gfx803-f8c6468d42e1202f7aeaeb961ab003aa0a2daf99.tar.gz
stable-diffusion-webui-gfx803-f8c6468d42e1202f7aeaeb961ab003aa0a2daf99.tar.bz2
stable-diffusion-webui-gfx803-f8c6468d42e1202f7aeaeb961ab003aa0a2daf99.zip
Merge branch 'master' into vae-picker
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py14
1 files changed, 13 insertions, 1 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 850f7b7b..6ab85b65 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1,6 +1,7 @@
import collections
import os.path
import sys
+import gc
from collections import namedtuple
import torch
import re
@@ -214,6 +215,12 @@ def load_model(checkpoint_info=None):
if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}")
+ if shared.sd_model:
+ sd_hijack.model_hijack.undo_hijack(shared.sd_model)
+ shared.sd_model = None
+ gc.collect()
+ devices.torch_gc()
+
sd_config = OmegaConf.load(checkpoint_info.config)
if should_hijack_inpainting(checkpoint_info):
@@ -227,6 +234,7 @@ def load_model(checkpoint_info=None):
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
do_inpainting_hijack()
+
sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info)
@@ -246,14 +254,18 @@ def load_model(checkpoint_info=None):
return sd_model
-def reload_model_weights(sd_model, info=None, force=False):
+def reload_model_weights(sd_model=None, info=None, force=False):
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
+
+ if not sd_model:
+ sd_model = shared.sd_model
if sd_model.sd_model_checkpoint == checkpoint_info.filename and not force:
return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+ del sd_model
checkpoints_loaded.clear()
load_model(checkpoint_info)
return shared.sd_model