aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 729f03d7..fb31a793 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -290,6 +290,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
model.is_sdxl = hasattr(model, 'conditioner')
+ model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
+ model.is_sd1 = not model.is_sdxl and not model.is_sd2
+
if model.is_sdxl:
sd_models_xl.extend_sdxl(model)
@@ -323,7 +326,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
timer.record("apply half()")
- devices.dtype_unet = model.model.diffusion_model.dtype
+ devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
@@ -491,7 +494,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
sd_model = None
try:
- with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
+ with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
sd_model = instantiate_from_config(sd_config.model)
except Exception:
pass