diff options
author | evshiron <evshiron@gmail.com> | 2022-11-02 04:31:33 +0000 |
---|---|---|
committer | evshiron <evshiron@gmail.com> | 2022-11-02 04:31:33 +0000 |
commit | 51e0a83969925e6b7b12b9b087d028c16ce04e3c (patch) | |
tree | bbccf573783cd2449e3774fe949c81dced83b6d4 /modules/sd_models.py | |
parent | 1a4ff2de6a835cd8cc1590bbc1a8dedb5ad37e5b (diff) | |
parent | 55688c48806f9383f3a56f6b9a0ab8fbf205edd2 (diff) | |
download | stable-diffusion-webui-gfx803-51e0a83969925e6b7b12b9b087d028c16ce04e3c.tar.gz stable-diffusion-webui-gfx803-51e0a83969925e6b7b12b9b087d028c16ce04e3c.tar.bz2 stable-diffusion-webui-gfx803-51e0a83969925e6b7b12b9b087d028c16ce04e3c.zip |
Merge branch 'master' into fix/progress-api
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r-- | modules/sd_models.py | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index f86dc3ed..90007da3 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,6 +1,7 @@ import collections
import os.path
import sys
+import gc
from collections import namedtuple
import torch
import re
@@ -220,6 +221,12 @@ def load_model(checkpoint_info=None): if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}")
+ if shared.sd_model:
+ sd_hijack.model_hijack.undo_hijack(shared.sd_model)
+ shared.sd_model = None
+ gc.collect()
+ devices.torch_gc()
+
sd_config = OmegaConf.load(checkpoint_info.config)
if should_hijack_inpainting(checkpoint_info):
@@ -233,6 +240,7 @@ def load_model(checkpoint_info=None): checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
do_inpainting_hijack()
+
sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info)
@@ -252,14 +260,18 @@ def load_model(checkpoint_info=None): return sd_model
-def reload_model_weights(sd_model, info=None):
+def reload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
+ if not sd_model:
+ sd_model = shared.sd_model
+
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+ del sd_model
checkpoints_loaded.clear()
load_model(checkpoint_info)
return shared.sd_model
|