diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-12 20:52:43 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-12 20:52:43 +0000 |
commit | da464a3fb39ecc6ea7b22fe87271194480d8501c (patch) | |
tree | fd67d92762d0490d9d4784aaae3f2a3c2f31c6ca /modules/sd_models.py | |
parent | af081211ee93622473ee575de30fed2fd8263c09 (diff) | |
download | stable-diffusion-webui-gfx803-da464a3fb39ecc6ea7b22fe87271194480d8501c.tar.gz stable-diffusion-webui-gfx803-da464a3fb39ecc6ea7b22fe87271194480d8501c.tar.bz2 stable-diffusion-webui-gfx803-da464a3fb39ecc6ea7b22fe87271194480d8501c.zip |
SDXL support
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r-- | modules/sd_models.py | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index 8d639583..e4aae597 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -411,6 +411,7 @@ def repair_config(sd_config): sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
+sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
class SdModelData:
@@ -445,6 +446,15 @@ class SdModelData: model_data = SdModelData()
+def get_empty_cond(sd_model):
+ if hasattr(sd_model, 'conditioner'):
+ d = sd_model.get_learned_conditioning([""])
+ return d['crossattn']
+ else:
+ return sd_model.cond_stage_model([""])
+
+
+
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -465,7 +475,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
- clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
+ clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict or sdxl_clip_weight in state_dict
timer.record("find config")
@@ -517,7 +527,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("scripts callbacks")
with devices.autocast(), torch.no_grad():
- sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
+ sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
timer.record("calculate empty prompt")
|