diff options
author | Tim Patton <38817597+pattontim@users.noreply.github.com> | 2022-11-22 15:13:07 +0000 |
---|---|---|
committer | Tim Patton <38817597+pattontim@users.noreply.github.com> | 2022-11-22 15:13:07 +0000 |
commit | ac90cf38c6b55d57d37923aa1fe86c7374e32d0b (patch) | |
tree | b58c7821989a1b0559cd8bbbb447adeea1955675 | |
parent | 210cb4c128afdd65fa998229a97d0694154983ea (diff) | |
download | stable-diffusion-webui-gfx803-ac90cf38c6b55d57d37923aa1fe86c7374e32d0b.tar.gz stable-diffusion-webui-gfx803-ac90cf38c6b55d57d37923aa1fe86c7374e32d0b.tar.bz2 stable-diffusion-webui-gfx803-ac90cf38c6b55d57d37923aa1fe86c7374e32d0b.zip |
safetensors optional for now
-rw-r--r-- | modules/sd_models.py | 9 | ||||
-rw-r--r-- | requirements.txt | 1 |
2 files changed, 8 insertions, 2 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index 2bbb3bf5..75f7ab09 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -4,7 +4,6 @@ import sys import gc
from collections import namedtuple
import torch
-from safetensors.torch import load_file, save_file
import re
from omegaconf import OmegaConf
@@ -149,6 +148,10 @@ def torch_load(model_filename, model_info, map_override=None): # safely load weights
# TODO: safetensors supports zero copy fast load to gpu, see issue #684.
# GPU only for now, see https://github.com/huggingface/safetensors/issues/95
+ try:
+ from safetensors.torch import load_file
+ except ImportError as e:
+ raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}")
return load_file(model_filename, device='cuda')
else:
return torch.load(model_filename, map_location=map_override)
@@ -157,6 +160,10 @@ def torch_save(model, output_filename): basename, exttype = os.path.splitext(output_filename)
if(checkpoint_types[exttype] == 'safetensors'):
# [===== >] Reticulating brines...
+ try:
+ from safetensors.torch import save_file
+ except ImportError as e:
+ raise ImportError(f"Export as safetensors selected, yet it is not installed, use `pip install safetensors`: {e}")
save_file(model, output_filename, metadata={"format": "pt"})
else:
torch.save(model, output_filename)
diff --git a/requirements.txt b/requirements.txt index f7de9f70..762db4f3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,4 +28,3 @@ kornia lark
inflection
GitPython
-safetensors
|