author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-22 15:49:08 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-22 15:49:08 +0000 |
commit | 016554e43740e0b7ded75e89255de81270de9d6c (patch) | |
tree | 127ecb9d92ce53f8e351270e8ef1c27784af3088 /modules/sd_models.py | |
parent | bb7dd7b64668d4b645dba38a3bc52be452d14eb8 (diff) | |
download | stable-diffusion-webui-gfx803-016554e43740e0b7ded75e89255de81270de9d6c.tar.gz stable-diffusion-webui-gfx803-016554e43740e0b7ded75e89255de81270de9d6c.tar.bz2 stable-diffusion-webui-gfx803-016554e43740e0b7ded75e89255de81270de9d6c.zip |
add --medvram-sdxl
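The new flag is not part of the modules/sd_models.py diff shown below; it is presumably registered with the webui's command-line parser (modules/cmd_args.py in this repository). A minimal, self-contained sketch of such a registration, with the help text assumed rather than quoted from the commit:

```python
# Minimal sketch (assumption): how a flag like --medvram-sdxl would be registered
# alongside --medvram in the webui's argparse-based option parser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--medvram",
    action="store_true",
    help="offload model parts between CPU and GPU to reduce VRAM usage",
)
parser.add_argument(
    "--medvram-sdxl",
    action="store_true",
    help="apply the --medvram optimization only to SDXL models",
)

# argparse converts dashes to underscores, so code reads cmd_opts.medvram_sdxl.
cmd_opts = parser.parse_args([])
print(cmd_opts.medvram_sdxl)  # False unless --medvram-sdxl was passed
```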
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r-- | modules/sd_models.py | 16 |
1 file changed, 8 insertions, 8 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 27d15e66..4331853a 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -517,7 +517,7 @@ def get_empty_cond(sd_model):
 
 
 def send_model_to_cpu(m):
-    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+    if m.lowvram:
         lowvram.send_everything_to_cpu()
     else:
         m.to(devices.cpu)
@@ -525,17 +525,17 @@ def send_model_to_cpu(m):
     devices.torch_gc()
 
 
-def model_target_device():
-    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+def model_target_device(m):
+    if lowvram.is_needed(m):
         return devices.cpu
     else:
         return devices.device
 
 
 def send_model_to_device(m):
-    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
-        lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
-    else:
+    lowvram.apply(m)
+
+    if not m.lowvram:
         m.to(shared.device)
 
 
@@ -601,7 +601,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
         '': torch.float16,
     }
 
-    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion):
+    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
 
     timer.record("load weights from state dict")
@@ -743,7 +743,7 @@ def reload_model_weights(sd_model=None, info=None):
     script_callbacks.model_loaded_callback(sd_model)
     timer.record("script callbacks")
 
-    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
+    if not sd_model.lowvram:
         sd_model.to(devices.device)
         timer.record("move model to device")
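The refactor delegates the per-model low-VRAM decision to helpers in modules/lowvram.py, which are not included in this page's diff: lowvram.is_needed(m), lowvram.apply(m), and the m.lowvram attribute they maintain. A plausible, self-contained sketch of those helpers, consistent with how they are called above; the function bodies, the SDXL detection, and the cmd_opts stand-in are assumptions:

```python
# Plausible sketch (assumption) of the modules/lowvram.py helpers used by the diff.
from types import SimpleNamespace

# Stand-in for shared.cmd_opts; in the webui these values come from the CLI flags.
cmd_opts = SimpleNamespace(lowvram=False, medvram=False, medvram_sdxl=True)


def is_needed(sd_model):
    # --medvram-sdxl enables offloading only for SDXL checkpoints, assumed here
    # to be recognizable by their 'conditioner' attribute.
    return (
        cmd_opts.lowvram
        or cmd_opts.medvram
        or (cmd_opts.medvram_sdxl and hasattr(sd_model, "conditioner"))
    )


def setup_for_low_vram(sd_model, use_medvram):
    # Placeholder for the real setup, which installs hooks that shuttle submodules
    # between CPU and GPU on demand; here it only records the decision.
    sd_model.lowvram = True


def apply(sd_model):
    # Decide once per model whether offloading is needed and record the result on
    # the model itself, so callers such as send_model_to_device() only check m.lowvram.
    if is_needed(sd_model):
        setup_for_low_vram(sd_model, not cmd_opts.lowvram)
    else:
        sd_model.lowvram = False


# Usage mirroring the new send_model_to_device(m): apply() tags the model, and the
# caller moves it to the GPU only when low-VRAM offloading is not in effect.
model = SimpleNamespace(conditioner=object())
apply(model)
print(model.lowvram)  # True: medvram_sdxl is set and the model looks like SDXL
```

Storing the decision on the model itself is what lets send_model_to_cpu(), send_model_to_device(), and reload_model_weights() above test m.lowvram instead of re-reading the command-line flags for every model.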