diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-16 09:11:01 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-16 09:11:01 +0000 |
commit | eaba3d7349c6f0e151be66ade3fdc848d693a10d (patch) | |
tree | 45e4741b441756a2ddadfed70f4fbda4f6cb3022 /modules/sd_models.py | |
parent | 57e59c14c8a13a99d6422597d27d92ad10a51ca1 (diff) | |
download | stable-diffusion-webui-gfx803-eaba3d7349c6f0e151be66ade3fdc848d693a10d.tar.gz stable-diffusion-webui-gfx803-eaba3d7349c6f0e151be66ade3fdc848d693a10d.tar.bz2 stable-diffusion-webui-gfx803-eaba3d7349c6f0e151be66ade3fdc848d693a10d.zip |
send weights to target device instead of CPU memory
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r-- | modules/sd_models.py | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index f6fbdcd6..f912fe16 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -518,6 +518,13 @@ def send_model_to_cpu(m): devices.torch_gc()
+def model_target_device():
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ return devices.cpu
+ else:
+ return devices.device
+
+
def send_model_to_device(m):
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
@@ -579,7 +586,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("create model")
- with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+ if shared.cmd_opts.no_half:
+ weight_dtype_conversion = None
+ else:
+ weight_dtype_conversion = {
+ 'first_stage_model': None,
+ '': torch.float16,
+ }
+
+ with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion):
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
timer.record("load weights from state dict")
|