aboutsummaryrefslogtreecommitdiffstats
path: root/extensions-builtin/LDSR/ldsr_model_arch.py
diff options
context:
space:
mode:
authorunknown <mcgpapu@gmail.com>2022-12-25 08:03:55 +0000
committerunknown <mcgpapu@gmail.com>2022-12-25 08:03:55 +0000
commit876da1259965130603f2a7fea505cfa0fce09e2e (patch)
treeccb8b89d64480a4bd224b311702ffeb13b8fe754 /extensions-builtin/LDSR/ldsr_model_arch.py
parentd6fdfde9d70f1b86b696240fb0a0c8f2a4d024f6 (diff)
parentc6f347b81f584b6c0d44af7a209983284dbb52d2 (diff)
downloadstable-diffusion-webui-gfx803-876da1259965130603f2a7fea505cfa0fce09e2e.tar.gz
stable-diffusion-webui-gfx803-876da1259965130603f2a7fea505cfa0fce09e2e.tar.bz2
stable-diffusion-webui-gfx803-876da1259965130603f2a7fea505cfa0fce09e2e.zip
Merge branch 'master' of github.com:AUTOMATIC1111/stable-diffusion-webui
Diffstat (limited to 'extensions-builtin/LDSR/ldsr_model_arch.py')
-rw-r--r--extensions-builtin/LDSR/ldsr_model_arch.py12
1 files changed, 9 insertions, 3 deletions
diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py
index 8b048ae0..0ad49f4e 100644
--- a/extensions-builtin/LDSR/ldsr_model_arch.py
+++ b/extensions-builtin/LDSR/ldsr_model_arch.py
@@ -1,3 +1,4 @@
+import os
import gc
import time
import warnings
@@ -8,6 +9,7 @@ import torchvision
from PIL import Image
from einops import rearrange, repeat
from omegaconf import OmegaConf
+import safetensors.torch
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
@@ -24,12 +26,16 @@ class LDSR:
global cached_ldsr_model
if shared.opts.ldsr_cached and cached_ldsr_model is not None:
- print(f"Loading model from cache")
+ print("Loading model from cache")
model: torch.nn.Module = cached_ldsr_model
else:
print(f"Loading model from {self.modelPath}")
- pl_sd = torch.load(self.modelPath, map_location="cpu")
- sd = pl_sd["state_dict"]
+ _, extension = os.path.splitext(self.modelPath)
+ if extension.lower() == ".safetensors":
+ pl_sd = safetensors.torch.load_file(self.modelPath, device="cpu")
+ else:
+ pl_sd = torch.load(self.modelPath, map_location="cpu")
+ sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
config = OmegaConf.load(self.yamlPath)
config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
model: torch.nn.Module = instantiate_from_config(config.model)