From 016554e43740e0b7ded75e89255de81270de9d6c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 22 Aug 2023 18:49:08 +0300 Subject: add --medvram-sdxl --- modules/sd_models.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 27d15e66..4331853a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -517,7 +517,7 @@ def get_empty_cond(sd_model): def send_model_to_cpu(m): - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + if m.lowvram: lowvram.send_everything_to_cpu() else: m.to(devices.cpu) @@ -525,17 +525,17 @@ def send_model_to_cpu(m): devices.torch_gc() -def model_target_device(): - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: +def model_target_device(m): + if lowvram.is_needed(m): return devices.cpu else: return devices.device def send_model_to_device(m): - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: - lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram) - else: + lowvram.apply(m) + + if not m.lowvram: m.to(shared.device) @@ -601,7 +601,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): '': torch.float16, } - with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion): + with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion): load_model_weights(sd_model, checkpoint_info, state_dict, timer) timer.record("load weights from state dict") @@ -743,7 +743,7 @@ def reload_model_weights(sd_model=None, info=None): script_callbacks.model_loaded_callback(sd_model) timer.record("script callbacks") - if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + if not sd_model.lowvram: sd_model.to(devices.device) timer.record("move model to device") -- cgit v1.2.3