aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNicolas Patry <patry.nicolas@protonmail.com>2022-12-21 12:45:58 +0000
committerNicolas Patry <patry.nicolas@protonmail.com>2022-12-27 10:27:19 +0000
commit5ba04f9ec050a66e918571f07e8863f157f05b44 (patch)
treefcb45835424d0e0f819d7490afa1648a26732d5f
parent4af3ca5393151d61363c30eef4965e694eeac15e (diff)
downloadstable-diffusion-webui-gfx803-5ba04f9ec050a66e918571f07e8863f157f05b44.tar.gz
stable-diffusion-webui-gfx803-5ba04f9ec050a66e918571f07e8863f157f05b44.tar.bz2
stable-diffusion-webui-gfx803-5ba04f9ec050a66e918571f07e8863f157f05b44.zip
Attempting to solve slow loads for `safetensors`.
Fixes #5893
-rw-r--r--modules/sd_models.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index ecdd91c5..cd938656 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -168,7 +168,10 @@ def get_state_dict_from_checkpoint(pl_sd):
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
- pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location)
+ device = map_location or shared.weight_load_location
+ if device is None:
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)