aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_models.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-06 14:01:07 +0000
committerAUTOMATIC1111 <16777216c@gmail.com>2023-08-06 14:01:07 +0000
commitf1975b0213f5be400889ec04b3891d1cb571fe20 (patch)
tree874e4bd221209a5197f1f578f907cdc28b33a6b7 /modules/sd_models.py
parent57e8a11d17a6646fdf551320f5f714fba752987a (diff)
downloadstable-diffusion-webui-gfx803-f1975b0213f5be400889ec04b3891d1cb571fe20.tar.gz
stable-diffusion-webui-gfx803-f1975b0213f5be400889ec04b3891d1cb571fe20.tar.bz2
stable-diffusion-webui-gfx803-f1975b0213f5be400889ec04b3891d1cb571fe20.zip
initial refiner support
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py18
1 files changed, 17 insertions, 1 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index f6051604..981aa93d 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -289,11 +289,27 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
return res
+class SkipWritingToConfig:
+ """This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight."""
+
+ skip = False
+ previous = None
+
+ def __enter__(self):
+ self.previous = SkipWritingToConfig.skip
+ SkipWritingToConfig.skip = True
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_traceback):
+ SkipWritingToConfig.skip = self.previous
+
+
def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
- shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
+ if not SkipWritingToConfig.skip:
+ shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)