From eaba3d7349c6f0e151be66ade3fdc848d693a10d Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Wed, 16 Aug 2023 12:11:01 +0300
Subject: send weights to target device instead of CPU memory

---
 modules/sd_models.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index f6fbdcd6..f912fe16 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -518,6 +518,13 @@ def send_model_to_cpu(m):
     devices.torch_gc()


+def model_target_device():
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        return devices.cpu
+    else:
+        return devices.device
+
+
 def send_model_to_device(m):
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
         lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
@@ -579,7 +586,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):

     timer.record("create model")

-    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+    if shared.cmd_opts.no_half:
+        weight_dtype_conversion = None
+    else:
+        weight_dtype_conversion = {
+            'first_stage_model': None,
+            '': torch.float16,
+        }
+
+    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)

     timer.record("load weights from state dict")
--
cgit v1.2.3
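
Note: the weight_dtype_conversion dict passed to LoadStateDictOnMeta maps module-name prefixes to target dtypes: 'first_stage_model' (the VAE) keeps the precision stored in the checkpoint, while the empty-string entry acts as a catch-all that casts every other weight to float16. Below is a minimal, hypothetical sketch of how such a prefix-to-dtype lookup could behave; the pick_dtype helper is illustrative only and is not the actual sd_disable_initialization implementation.

import torch

def pick_dtype(param_name, dtype_map):
    # Prefer the longest matching module-name prefix; '' acts as a catch-all.
    # A value of None means "keep the dtype stored in the checkpoint".
    matches = [prefix for prefix in dtype_map if param_name.startswith(prefix)]
    best = max(matches, key=len) if matches else None
    return dtype_map.get(best)

weight_dtype_conversion = {
    'first_stage_model': None,   # VAE weights stay at their original precision
    '': torch.float16,           # all other weights are cast to fp16
}

print(pick_dtype('first_stage_model.decoder.conv_in.weight', weight_dtype_conversion))       # None
print(pick_dtype('model.diffusion_model.input_blocks.0.0.weight', weight_dtype_conversion))  # torch.float16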