aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJosh Watzman <github@jwatzman.org>2022-10-27 20:59:16 +0000
committerJosh Watzman <github@jwatzman.org>2022-10-27 21:01:06 +0000
commitb50ff4f4e4d4d6bf31e222832d3fe4cfde4703c9 (patch)
tree318905a69b66bcefcce4eeb4832372b667fe23f1
parent737eb28faca8be2bb996ee0930ec77d1f7ebd939 (diff)
downloadstable-diffusion-webui-gfx803-b50ff4f4e4d4d6bf31e222832d3fe4cfde4703c9.tar.gz
stable-diffusion-webui-gfx803-b50ff4f4e4d4d6bf31e222832d3fe4cfde4703c9.tar.bz2
stable-diffusion-webui-gfx803-b50ff4f4e4d4d6bf31e222832d3fe4cfde4703c9.zip
Reduce peak memory usage when changing models
A few tweaks to reduce peak memory usage, the biggest being that if we aren't using the checkpoint cache, we shouldn't duplicate the model state dict just to immediately throw it away. On my machine with 16GB of RAM, this change means I can typically change models, whereas before it would typically OOM.
-rw-r--r--modules/sd_models.py11
1 files changed, 7 insertions, 4 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index e697bb72..203e99a8 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -170,7 +170,9 @@ def load_model_weights(model, checkpoint_info):
print(f"Global Step: {pl_sd['global_step']}")
sd = get_state_dict_from_checkpoint(pl_sd)
- missing, extra = model.load_state_dict(sd, strict=False)
+ del pl_sd
+ model.load_state_dict(sd, strict=False)
+ del sd
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
@@ -194,9 +196,10 @@ def load_model_weights(model, checkpoint_info):
model.first_stage_model.to(devices.dtype_vae)
- checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
- while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
- checkpoints_loaded.popitem(last=False) # LRU
+ if shared.opts.sd_checkpoint_cache > 0:
+ checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+ while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+ checkpoints_loaded.popitem(last=False) # LRU
else:
print(f"Loading weights [{sd_model_hash}] from cache")
checkpoints_loaded.move_to_end(checkpoint_info)