From 6d8dcdefa07d5f8f7e528046b0facdcc51185e60 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 09:16:01 +0300 Subject: initial SDXL refiner support --- modules/sd_hijack.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 647cdfbe..2b274c18 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -180,21 +180,29 @@ class StableDiffusionModelHijack: def hijack(self, m): conditioner = getattr(m, 'conditioner', None) if conditioner: + text_cond_models = [] + for i in range(len(conditioner.embedders)): embedder = conditioner.embedders[i] typename = type(embedder).__name__ if typename == 'FrozenOpenCLIPEmbedder': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) - m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) - conditioner.embedders[i] = m.cond_stage_model + conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) + text_cond_models.append(conditioner.embedders[i]) if typename == 'FrozenCLIPEmbedder': - model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + model_embeddings = embedder.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) - m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) - conditioner.embedders[i] = m.cond_stage_model + conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) + text_cond_models.append(conditioner.embedders[i]) if typename == 'FrozenOpenCLIPEmbedder2': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self) + text_cond_models.append(conditioner.embedders[i]) + + if len(text_cond_models) == 1: + m.cond_stage_model = text_cond_models[0] + else: + m.cond_stage_model = conditioner if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings -- cgit v1.2.3