From 0815c45bcdec0a2e5c60bdd5b33d95813d799c01 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 16 Aug 2023 10:44:17 +0300 Subject: send weights to target device instead of CPU memory --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index f6fbdcd6..b01d44c5 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -579,7 +579,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("create model") - with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu): + with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.device): load_model_weights(sd_model, checkpoint_info, state_dict, timer) timer.record("load weights from state dict") -- cgit v1.2.3 From 57e59c14c8a13a99d6422597d27d92ad10a51ca1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 16 Aug 2023 11:28:00 +0300 Subject: Revert "send weights to target device instead of CPU memory" This reverts commit 0815c45bcdec0a2e5c60bdd5b33d95813d799c01. --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index b01d44c5..f6fbdcd6 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -579,7 +579,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("create model") - with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.device): + with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu): load_model_weights(sd_model, checkpoint_info, state_dict, timer) timer.record("load weights from state dict") -- cgit v1.2.3 From eaba3d7349c6f0e151be66ade3fdc848d693a10d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 16 Aug 2023 12:11:01 +0300 Subject: send weights to target device instead of CPU memory --- modules/sd_models.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index f6fbdcd6..f912fe16 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -518,6 +518,13 @@ def send_model_to_cpu(m): devices.torch_gc() +def model_target_device(): + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + return devices.cpu + else: + return devices.device + + def send_model_to_device(m): if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram) @@ -579,7 +586,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("create model") - with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu): + if shared.cmd_opts.no_half: + weight_dtype_conversion = None + else: + weight_dtype_conversion = { + 'first_stage_model': None, + '': torch.float16, + } + + with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion): load_model_weights(sd_model, checkpoint_info, state_dict, timer) timer.record("load weights from state dict") -- cgit v1.2.3 From 0dc74545c0b5510911757ed9f2be703aab58f014 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 17 Aug 2023 07:54:07 +0300 Subject: resolve the issue with loading fp16 checkpoints while using --no-half --- modules/sd_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index f912fe16..685585b1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -343,7 +343,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.to(memory_format=torch.channels_last) timer.record("apply channels_last") - if not shared.cmd_opts.no_half: + if shared.cmd_opts.no_half: + model.float() + timer.record("apply float()") + else: vae = model.first_stage_model depth_model = getattr(model, 'depth_model', None) -- cgit v1.2.3