diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-14 06:16:01 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-14 06:16:01 +0000 |
commit | 6d8dcdefa07d5f8f7e528046b0facdcc51185e60 (patch) | |
tree | c5298147907e890dc5e3094a9713f8e9a67c889e /modules/sd_hijack.py | |
parent | dc3906185656dae75fcefe96625b1dcd0d31579c (diff) | |
download | stable-diffusion-webui-gfx803-6d8dcdefa07d5f8f7e528046b0facdcc51185e60.tar.gz stable-diffusion-webui-gfx803-6d8dcdefa07d5f8f7e528046b0facdcc51185e60.tar.bz2 stable-diffusion-webui-gfx803-6d8dcdefa07d5f8f7e528046b0facdcc51185e60.zip |
initial SDXL refiner support
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r-- | modules/sd_hijack.py | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 647cdfbe..2b274c18 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -180,21 +180,29 @@ class StableDiffusionModelHijack: def hijack(self, m):
conditioner = getattr(m, 'conditioner', None)
if conditioner:
+ text_cond_models = []
+
for i in range(len(conditioner.embedders)):
embedder = conditioner.embedders[i]
typename = type(embedder).__name__
if typename == 'FrozenOpenCLIPEmbedder':
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
- m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
- conditioner.embedders[i] = m.cond_stage_model
+ conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
+ text_cond_models.append(conditioner.embedders[i])
if typename == 'FrozenCLIPEmbedder':
- model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+ model_embeddings = embedder.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
- m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
- conditioner.embedders[i] = m.cond_stage_model
+ conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
+ text_cond_models.append(conditioner.embedders[i])
if typename == 'FrozenOpenCLIPEmbedder2':
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
+ text_cond_models.append(conditioner.embedders[i])
+
+ if len(text_cond_models) == 1:
+ m.cond_stage_model = text_cond_models[0]
+ else:
+ m.cond_stage_model = conditioner
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
|