From 637815632f9f362c9959e53139d37e88ea9ace6f Mon Sep 17 00:00:00 2001
From: Tim Patton <38817597+pattontim@users.noreply.github.com>
Date: Sun, 20 Nov 2022 13:36:05 -0500
Subject: Generalize SD torch load/save to implement safetensor merging compat

---
 modules/sd_models.py | 25 ++++++++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 4ccdf30b..2f8c2c48 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -4,7 +4,7 @@ import sys
 import gc
 from collections import namedtuple
 import torch
-from safetensors.torch import load_file
+from safetensors.torch import load_file, save_file
 import re
 
 from omegaconf import OmegaConf
@@ -143,6 +143,22 @@ def transform_checkpoint_dict_key(k):
 
     return k
 
+def torch_load(model_filename, model_info, map_override=None):
+    map_override = shared.weight_load_location if not map_override else map_override
+    if checkpoint_types[model_info.exttype] == 'safetensors':
+        # safely load weights
+        # TODO: safetensors supports zero copy fast load to gpu, see issue #684
+        return load_file(model_filename, device=map_override)
+    else:
+        return torch.load(model_filename, map_location=map_override)
+
+def torch_save(model, output_filename):
+    basename, exttype = os.path.splitext(output_filename)
+    if checkpoint_types[exttype] == 'safetensors':
+        # [===== >] Reticulating brines...
+        save_file(model, output_filename, metadata={"format": "pt"})
+    else:
+        torch.save(model, output_filename)
 
 def get_state_dict_from_checkpoint(pl_sd):
     if "state_dict" in pl_sd:
@@ -175,12 +191,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
         # load from file
         print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
 
-        if(checkpoint_types[checkpoint_info.exttype] == 'safetensors'):
-            # safely load weights
-            # TODO: safetensors supports zero copy fast load to gpu, see issue #684
-            pl_sd = load_file(checkpoint_file, device=shared.weight_load_location)
-        else:
-            pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
+        pl_sd = torch_load(checkpoint_file, checkpoint_info)
 
         if "global_step" in pl_sd:
             print(f"Global Step: {pl_sd['global_step']}")
-- 
cgit v1.2.3
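
For readers outside the webui tree: both helpers dispatch purely on the checkpoint's file extension. checkpoint_types is assumed to be a mapping defined elsewhere in modules/sd_models.py from extension to format name (e.g. '.safetensors' -> 'safetensors'), and model_info.exttype its precomputed extension. The sketch below reproduces the same dispatch with nothing but os.path.splitext, so it runs standalone; the function names are illustrative, not the webui API.

    # Standalone sketch of the extension-based dispatch the patch introduces.
    # checkpoint_types/exttype from the webui are approximated here with
    # os.path.splitext.
    import os
    import torch
    from safetensors.torch import load_file, save_file

    def load_state_dict(path, device="cpu"):
        if os.path.splitext(path)[1].lower() == ".safetensors":
            # no pickle involved, so no arbitrary-code-execution risk on load
            return load_file(path, device=device)
        return torch.load(path, map_location=device)

    def save_state_dict(state_dict, path):
        if os.path.splitext(path)[1].lower() == ".safetensors":
            # save_file only serializes a flat {name: tensor} dict
            save_file(state_dict, path, metadata={"format": "pt"})
        else:
            torch.save(state_dict, path)

Note the asymmetry this papers over: torch.save can pickle arbitrary objects, while safetensors' save_file accepts only a flat dict of tensors, which is why the surrounding module extracts the bare state dict (get_state_dict_from_checkpoint) before anything is written back out.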
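
The commit subject mentions merging compatibility: with format-agnostic load/save, a checkpoint merger can mix .ckpt and .safetensors inputs and pick its output format from the filename alone. A hypothetical weighted-sum merge using the sketch above (the file names, the 0.3 ratio, and the "state_dict" unwrapping are illustrative; the unwrapping mirrors what get_state_dict_from_checkpoint does for pickled checkpoints):

    # Hypothetical 70/30 weighted-sum merge across formats. File names are
    # placeholders; .ckpt files may nest their weights under "state_dict".
    alpha = 0.3
    a = load_state_dict("modelA.ckpt")
    a = a.get("state_dict", a)  # unwrap pickled checkpoints if needed
    b = load_state_dict("modelB.safetensors")
    merged = {k: (1 - alpha) * a[k] + alpha * b[k]
              for k in a.keys() & b.keys()}
    save_state_dict(merged, "merged.safetensors")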