diff options
author | AUTOMATIC <16777216c@gmail.com> | 2023-01-04 11:53:03 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2023-01-04 11:53:03 +0000 |
commit | 68fbf4558f9fbf36ca58f92f1fc4681f0ebdf735 (patch) | |
tree | 8ae0070bedc9965dcecedeba568463849122dc11 | |
parent | c4796bcc679c145b9cd53011c2a45c95b0ddabfa (diff) | |
parent | 5a523d13050a5ede43c473767f29dfe2e391136a (diff) | |
download | stable-diffusion-webui-gfx803-68fbf4558f9fbf36ca58f92f1fc4681f0ebdf735.tar.gz stable-diffusion-webui-gfx803-68fbf4558f9fbf36ca58f92f1fc4681f0ebdf735.tar.bz2 stable-diffusion-webui-gfx803-68fbf4558f9fbf36ca58f92f1fc4681f0ebdf735.zip |
Merge remote-tracking branch 'Narsil/fix_safetensors_load_speed'
-rw-r--r-- | modules/sd_models.py | 5 | ||||
-rw-r--r-- | requirements_versions.txt | 2 |
2 files changed, 5 insertions, 2 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index a568823d..ee918f24 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -171,7 +171,10 @@ def get_state_dict_from_checkpoint(pl_sd): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
- pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location)
+ device = map_location or shared.weight_load_location
+ if device is None:
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
diff --git a/requirements_versions.txt b/requirements_versions.txt index 836523ba..975102d9 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -26,5 +26,5 @@ lark==1.1.2 inflection==0.5.1
GitPython==3.1.27
torchsde==0.2.5
-safetensors==0.2.5
+safetensors==0.2.7
httpcore<=0.15
|