diff options
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r-- | modules/sd_models.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index 6dca4ddf..76a89e88 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -171,7 +171,10 @@ def get_state_dict_from_checkpoint(pl_sd): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
_, extension = os.path.splitext(checkpoint_file)
if extension.lower() == ".safetensors":
- pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location)
+ device = map_location or shared.weight_load_location
+ if device is None:
+ device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
+ pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
@@ -305,9 +308,6 @@ def load_model(checkpoint_info=None): sd_config.model.params.unet_config.params.in_channels = 9
sd_config.model.params.finetune_keys = None
- # Create a "fake" config with a different name so that we know to unload it when switching models.
- checkpoint_info = checkpoint_info._replace(config=checkpoint_config.replace(".yaml", "-inpainting.yaml"))
-
if not hasattr(sd_config.model.params, "use_ema"):
sd_config.model.params.use_ema = False
|