From eaba3d7349c6f0e151be66ade3fdc848d693a10d Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Wed, 16 Aug 2023 12:11:01 +0300
Subject: send weights to target device instead of CPU memory

---
 modules/sd_disable_initialization.py | 24 +++++++++++++++---------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index 719eeb93..8863107a 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -155,10 +155,16 @@ class LoadStateDictOnMeta(ReplaceHelper):
     ```
     """
 
-    def __init__(self, state_dict, device):
+    def __init__(self, state_dict, device, weight_dtype_conversion=None):
         super().__init__()
         self.state_dict = state_dict
         self.device = device
+        self.weight_dtype_conversion = weight_dtype_conversion or {}
+        self.default_dtype = self.weight_dtype_conversion.get('')
+
+    def get_weight_dtype(self, key):
+        key_first_term, _ = key.split('.', 1)
+        return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)
 
     def __enter__(self):
         if shared.cmd_opts.disable_model_loading_ram_optimization:
@@ -167,24 +173,24 @@ class LoadStateDictOnMeta(ReplaceHelper):
         sd = self.state_dict
         device = self.device
 
-        def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
+        def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
             used_param_keys = []
 
-            for name, param in self._parameters.items():
+            for name, param in module._parameters.items():
                 if param is None:
                     continue
 
                 key = prefix + name
                 sd_param = sd.pop(key, None)
                 if sd_param is not None:
-                    state_dict[key] = sd_param
+                    state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
                     used_param_keys.append(key)
 
                 if param.is_meta:
                     dtype = sd_param.dtype if sd_param is not None else param.dtype
-                    self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
+                    module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
 
-            for name in self._buffers:
+            for name in module._buffers:
                 key = prefix + name
 
                 sd_param = sd.pop(key, None)
@@ -192,12 +198,12 @@ class LoadStateDictOnMeta(ReplaceHelper):
                     state_dict[key] = sd_param
                     used_param_keys.append(key)
 
-            original(self, state_dict, prefix, *args, **kwargs)
+            original(module, state_dict, prefix, *args, **kwargs)
 
             for key in used_param_keys:
                 state_dict.pop(key, None)
 
-        def load_state_dict(original, self, state_dict, strict=True):
+        def load_state_dict(original, module, state_dict, strict=True):
             """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
             because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
             all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
@@ -212,7 +218,7 @@ class LoadStateDictOnMeta(ReplaceHelper):
             if state_dict == sd:
                 state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
 
-            original(self, state_dict, strict=strict)
+            original(module, state_dict, strict=strict)
 
         module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
         module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
--
cgit v1.2.3
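
A minimal usage sketch of the weight_dtype_conversion argument this commit introduces. The dtype mapping, the target device, and the model/state_dict variables below are illustrative assumptions rather than code from this patch; what the diff itself establishes is the constructor signature and the lookup rule in get_weight_dtype (a key is matched by its first dot-separated term, with '' supplying the default dtype).

import torch

from modules import sd_disable_initialization

# Assumed context (not part of this diff): `state_dict` holds checkpoint
# tensors already read from disk, and `model` is a torch.nn.Module whose
# parameters were created on the meta device earlier in the loading path.
weight_dtype_conversion = {
    'first_stage_model': None,  # keys under "first_stage_model." keep their stored dtype; .to(dtype=None) is a no-op
    '': torch.float16,          # '' is the fallback dtype applied to every other top-level key
}

with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=torch.device('cuda'), weight_dtype_conversion=weight_dtype_conversion):
    model.load_state_dict(state_dict, strict=False)

With a mapping like this, each tensor is cast inside the patched _load_from_state_dict and materialized directly on the target device, rather than first being assembled as a full-precision copy in CPU RAM, which is what the commit subject refers to.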