diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-16 09:11:01 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-16 09:11:01 +0000 |
commit | eaba3d7349c6f0e151be66ade3fdc848d693a10d (patch) | |
tree | 45e4741b441756a2ddadfed70f4fbda4f6cb3022 | |
parent | 57e59c14c8a13a99d6422597d27d92ad10a51ca1 (diff) | |
download | stable-diffusion-webui-gfx803-eaba3d7349c6f0e151be66ade3fdc848d693a10d.tar.gz stable-diffusion-webui-gfx803-eaba3d7349c6f0e151be66ade3fdc848d693a10d.tar.bz2 stable-diffusion-webui-gfx803-eaba3d7349c6f0e151be66ade3fdc848d693a10d.zip |
send weights to target device instead of CPU memory
-rw-r--r-- | modules/sd_disable_initialization.py | 24 | ||||
-rw-r--r-- | modules/sd_models.py | 17 |
2 files changed, 31 insertions, 10 deletions
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 719eeb93..8863107a 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -155,10 +155,16 @@ class LoadStateDictOnMeta(ReplaceHelper): ```
"""
- def __init__(self, state_dict, device):
+ def __init__(self, state_dict, device, weight_dtype_conversion=None):
super().__init__()
self.state_dict = state_dict
self.device = device
+ self.weight_dtype_conversion = weight_dtype_conversion or {}
+ self.default_dtype = self.weight_dtype_conversion.get('')
+
+ def get_weight_dtype(self, key):
+ key_first_term, _ = key.split('.', 1)
+ return self.weight_dtype_conversion.get(key_first_term, self.default_dtype)
def __enter__(self):
if shared.cmd_opts.disable_model_loading_ram_optimization:
@@ -167,24 +173,24 @@ class LoadStateDictOnMeta(ReplaceHelper): sd = self.state_dict
device = self.device
- def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
+ def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
used_param_keys = []
- for name, param in self._parameters.items():
+ for name, param in module._parameters.items():
if param is None:
continue
key = prefix + name
sd_param = sd.pop(key, None)
if sd_param is not None:
- state_dict[key] = sd_param
+ state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
used_param_keys.append(key)
if param.is_meta:
dtype = sd_param.dtype if sd_param is not None else param.dtype
- self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
+ module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
- for name in self._buffers:
+ for name in module._buffers:
key = prefix + name
sd_param = sd.pop(key, None)
@@ -192,12 +198,12 @@ class LoadStateDictOnMeta(ReplaceHelper): state_dict[key] = sd_param
used_param_keys.append(key)
- original(self, state_dict, prefix, *args, **kwargs)
+ original(module, state_dict, prefix, *args, **kwargs)
for key in used_param_keys:
state_dict.pop(key, None)
- def load_state_dict(original, self, state_dict, strict=True):
+ def load_state_dict(original, module, state_dict, strict=True):
"""torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
@@ -212,7 +218,7 @@ class LoadStateDictOnMeta(ReplaceHelper): if state_dict == sd:
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
- original(self, state_dict, strict=strict)
+ original(module, state_dict, strict=strict)
module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
diff --git a/modules/sd_models.py b/modules/sd_models.py index f6fbdcd6..f912fe16 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -518,6 +518,13 @@ def send_model_to_cpu(m): devices.torch_gc()
+def model_target_device():
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ return devices.cpu
+ else:
+ return devices.device
+
+
def send_model_to_device(m):
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
@@ -579,7 +586,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("create model")
- with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+ if shared.cmd_opts.no_half:
+ weight_dtype_conversion = None
+ else:
+ weight_dtype_conversion = {
+ 'first_stage_model': None,
+ '': torch.float16,
+ }
+
+ with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion):
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
timer.record("load weights from state dict")
|