aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_models_config.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-07-16 09:04:53 +0000
committerGitHub <noreply@github.com>2023-07-16 09:04:53 +0000
commit0198eaec455157a7dc1c950708d1ec95bcf4629c (patch)
tree33d8e22448356c2f7c9455b3af17353ef497bbac /modules/sd_models_config.py
parent9d3dd64fe9e95873347710ca1df1f1e88d1908e1 (diff)
parent14cf434bc36d0ef31f31d4c6cd2bd15d7857d5c8 (diff)
downloadstable-diffusion-webui-gfx803-0198eaec455157a7dc1c950708d1ec95bcf4629c.tar.gz
stable-diffusion-webui-gfx803-0198eaec455157a7dc1c950708d1ec95bcf4629c.tar.bz2
stable-diffusion-webui-gfx803-0198eaec455157a7dc1c950708d1ec95bcf4629c.zip
Merge pull request #11757 from AUTOMATIC1111/sdxl
SD XL support
Diffstat (limited to 'modules/sd_models_config.py')
-rw-r--r--modules/sd_models_config.py9
1 files changed, 8 insertions, 1 deletions
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 9bfe1237..8266fa39 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -6,12 +6,15 @@ from modules import shared, paths, sd_disable_initialization
sd_configs_path = shared.sd_configs_path
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
+sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")
config_default = shared.sd_default_config
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
+config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
+config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
@@ -68,7 +71,11 @@ def guess_model_config_from_state_dict(sd, filename):
diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)
- if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
+ if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
+ return config_sdxl
+ if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
+ return config_sdxl_refiner
+ elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
return config_depth_model
elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
return config_unclip