diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-16 09:04:53 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-16 09:04:53 +0000 |
commit | 0198eaec455157a7dc1c950708d1ec95bcf4629c (patch) | |
tree | 33d8e22448356c2f7c9455b3af17353ef497bbac /modules/sd_models.py | |
parent | 9d3dd64fe9e95873347710ca1df1f1e88d1908e1 (diff) | |
parent | 14cf434bc36d0ef31f31d4c6cd2bd15d7857d5c8 (diff) | |
download | stable-diffusion-webui-gfx803-0198eaec455157a7dc1c950708d1ec95bcf4629c.tar.gz stable-diffusion-webui-gfx803-0198eaec455157a7dc1c950708d1ec95bcf4629c.tar.bz2 stable-diffusion-webui-gfx803-0198eaec455157a7dc1c950708d1ec95bcf4629c.zip |
Merge pull request #11757 from AUTOMATIC1111/sdxl
SD XL support
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r-- | modules/sd_models.py | 33 |
1 files changed, 25 insertions, 8 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 060e0007..729f03d7 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
@@ -289,6 +289,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
if state_dict is None:
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
+ model.is_sdxl = hasattr(model, 'conditioner')
+ if model.is_sdxl:
+ sd_models_xl.extend_sdxl(model)
+
model.load_state_dict(state_dict, strict=False)
del state_dict
timer.record("apply weights to model")
@@ -334,7 +338,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.sd_checkpoint_info = checkpoint_info
shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256
- model.logvar = model.logvar.to(devices.device) # fix for training
+ if hasattr(model, 'logvar'):
+ model.logvar = model.logvar.to(devices.device) # fix for training
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
@@ -391,10 +396,11 @@ def repair_config(sd_config):
if not hasattr(sd_config.model.params, "use_ema"):
sd_config.model.params.use_ema = False
- if shared.cmd_opts.no_half:
- sd_config.model.params.unet_config.params.use_fp16 = False
- elif shared.cmd_opts.upcast_sampling:
- sd_config.model.params.unet_config.params.use_fp16 = True
+ if hasattr(sd_config.model.params, 'unet_config'):
+ if shared.cmd_opts.no_half:
+ sd_config.model.params.unet_config.params.use_fp16 = False
+ elif shared.cmd_opts.upcast_sampling:
+ sd_config.model.params.unet_config.params.use_fp16 = True
if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
@@ -407,6 +413,8 @@ def repair_config(sd_config):
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
+sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
+sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
class SdModelData:
@@ -441,6 +449,15 @@ class SdModelData:
model_data = SdModelData()
+def get_empty_cond(sd_model):
+ if hasattr(sd_model, 'conditioner'):
+ d = sd_model.get_learned_conditioning([""])
+ return d['crossattn']
+ else:
+ return sd_model.cond_stage_model([""])
+
+
+
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -461,7 +478,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
- clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
+ clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)
timer.record("find config")
@@ -513,7 +530,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("scripts callbacks")
with devices.autocast(), torch.no_grad():
- sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
+ sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)
timer.record("calculate empty prompt")
|