diff options
author | wywywywy <wywywywy@gmail.com> | 2022-12-10 18:57:18 +0000 |
---|---|---|
committer | wywywywy <wywywywy@gmail.com> | 2022-12-10 18:57:18 +0000 |
commit | 8bcdd50461090a2dd238082b33f4c1423378ebbd (patch) | |
tree | 0c9f0f3ea522bb8c6914bc0af77ca570163481b2 /extensions-builtin/LDSR/ldsr_model_arch.py | |
parent | 685f9631b56ff8bd43bce24ff5ce0f9a0e9af490 (diff) | |
download | stable-diffusion-webui-gfx803-8bcdd50461090a2dd238082b33f4c1423378ebbd.tar.gz stable-diffusion-webui-gfx803-8bcdd50461090a2dd238082b33f4c1423378ebbd.tar.bz2 stable-diffusion-webui-gfx803-8bcdd50461090a2dd238082b33f4c1423378ebbd.zip |
Add safetensors support to LDSR
Diffstat (limited to 'extensions-builtin/LDSR/ldsr_model_arch.py')
-rw-r--r-- | extensions-builtin/LDSR/ldsr_model_arch.py | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index 8b048ae0..f5bd8ae4 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -1,3 +1,4 @@ +import os import gc import time import warnings @@ -8,6 +9,7 @@ import torchvision from PIL import Image from einops import rearrange, repeat from omegaconf import OmegaConf +import safetensors.torch from ldm.models.diffusion.ddim import DDIMSampler from ldm.util import instantiate_from_config, ismap @@ -28,8 +30,12 @@ class LDSR: model: torch.nn.Module = cached_ldsr_model else: print(f"Loading model from {self.modelPath}") - pl_sd = torch.load(self.modelPath, map_location="cpu") - sd = pl_sd["state_dict"] + _, extension = os.path.splitext(self.modelPath) + if extension.lower() == ".safetensors": + pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu") + else: + pl_sd = torch.load(self.modelPath, map_location="cpu") + sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd config = OmegaConf.load(self.yamlPath) config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1" model: torch.nn.Module = instantiate_from_config(config.model) |