diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-11-27 11:46:40 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-11-27 11:46:40 +0000 |
commit | 6074175faa751dde933aa8e15cd687ca4e4b4a23 (patch) | |
tree | 1bd98e8d763b5c55669a8726aa9cdb2decc98b2f | |
parent | f108782e30369dedfc66f22d21c2b72c77941de7 (diff) | |
download | stable-diffusion-webui-gfx803-6074175faa751dde933aa8e15cd687ca4e4b4a23.tar.gz stable-diffusion-webui-gfx803-6074175faa751dde933aa8e15cd687ca4e4b4a23.tar.bz2 stable-diffusion-webui-gfx803-6074175faa751dde933aa8e15cd687ca4e4b4a23.zip |
add safetensors to requirements
-rw-r--r-- | modules/sd_models.py | 11 | ||||
-rw-r--r-- | requirements.txt | 1 | ||||
-rw-r--r-- | requirements_versions.txt | 1 |
3 files changed, 7 insertions, 6 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index ae36841a..77236480 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -5,6 +5,7 @@ import gc from collections import namedtuple
import torch
import re
+import safetensors.torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
@@ -173,14 +174,12 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): # load from file
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
- if checkpoint_file.endswith(".safetensors"):
- try:
- from safetensors.torch import load_file
- except ImportError as e:
- raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}")
- pl_sd = load_file(checkpoint_file, device=shared.weight_load_location)
+ _, extension = os.path.splitext(checkpoint_file)
+ if extension.lower() == ".safetensors":
+ pl_sd = safetensors.torch.load_file(checkpoint_file, device=shared.weight_load_location)
else:
pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
+
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
diff --git a/requirements.txt b/requirements.txt index e4e5ec64..5f3d9623 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,3 +29,4 @@ lark inflection
GitPython
torchsde
+safetensors
diff --git a/requirements_versions.txt b/requirements_versions.txt index 8d557fe3..035fa82f 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -26,3 +26,4 @@ lark==1.1.2 inflection==0.5.1
GitPython==3.1.27
torchsde==0.2.5
+safetensors==0.2.5
|