Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r-- | modules/sd_hijack.py | 19 | +++++++++++++------
1 file changed, 13 insertions(+), 6 deletions(-)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 95a17093..eb679ef9 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -29,7 +29,7 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
 # new memory efficient cross attention blocks do not support hypernets and we already
 # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
 ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
-ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
+# ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
 
 # silence new console spam from SD2
 ldm.modules.attention.print = lambda *args: None
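
The hunk above relies on module-attribute monkey-patching: downstream code looks up ldm.modules.attention.MemoryEfficientCrossAttention by name at instantiation time, so rebinding that name makes SD2's memory-efficient path construct plain CrossAttention instead (which webui optimizes on its own). A minimal, self-contained sketch of the pattern, using a hypothetical somelib module rather than ldm itself:

# Rebinding a module-level name so later lookups resolve to the replacement.
# "somelib" and both classes are hypothetical stand-ins for ldm's attention module.
import types

somelib = types.ModuleType("somelib")

class CrossAttention:
    """Plain attention path (stands in for ldm's CrossAttention)."""

class MemoryEfficientCrossAttention(CrossAttention):
    """Optimized path we want to disable."""

somelib.CrossAttention = CrossAttention
somelib.MemoryEfficientCrossAttention = MemoryEfficientCrossAttention

# The hijack: code that later instantiates somelib.MemoryEfficientCrossAttention
# now gets CrossAttention instead.
somelib.MemoryEfficientCrossAttention = somelib.CrossAttention
assert somelib.MemoryEfficientCrossAttention is CrossAttention

The ldm.modules.attention.print = lambda *args: None line works the same way: assigning a module attribute named print shadows the builtin inside that one module, silencing its console output without touching print anywhere else.
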
@@ -82,17 +82,24 @@ class StableDiffusionModelHijack:
     embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
 
     def hijack(self, m):
-        if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
+
+        if shared.text_model_name == "XLMR-Large":
+            model_embeddings = m.cond_stage_model.roberta.embeddings
+            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
+            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+
+        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
             model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
             model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
             m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+            apply_optimizations()
         elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
             m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
             m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
-
+            apply_optimizations()
+
         self.clip = m.cond_stage_model
-
-        apply_optimizations()
+
         fix_checkpoint()
 
     def flatten(el):
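
Each branch of the new hijack() wraps the text encoder's token-embedding layer in EmbeddingsWithFixes (for XLMR-Large it wraps roberta's word_embeddings) so that textual-inversion vectors can later be spliced into the embedding output. Note also that apply_optimizations() moves from an unconditional call into the CLIP and OpenCLIP branches, so it is skipped for the XLMR-Large encoder. A hedged sketch of the wrapping idea, simplified from webui's actual class:

# Simplified sketch of the embedding-wrapper idea (not webui's exact class):
# wrap an nn.Embedding so chosen token positions can be overridden with
# externally supplied vectors, e.g. textual-inversion embeddings.
import torch
import torch.nn as nn

class EmbeddingsWithFixesSketch(nn.Module):
    def __init__(self, wrapped: nn.Embedding):
        super().__init__()
        self.wrapped = wrapped
        self.fixes = None  # list of (position, replacement_vector) pairs

    def forward(self, input_ids):
        out = self.wrapped(input_ids)
        if self.fixes:
            for pos, vec in self.fixes:
                out[:, pos] = vec  # splice the learned vector into that slot
        return out

base = nn.Embedding(49408, 768)
wrapped = EmbeddingsWithFixesSketch(base)
wrapped.fixes = [(1, torch.zeros(768))]
print(wrapped(torch.tensor([[0, 1, 2]])).shape)  # torch.Size([1, 3, 768])
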
@@ -133,8 +140,8 @@ class StableDiffusionModelHijack:
     def tokenize(self, text):
         _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
-        return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
+        return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count)
 
 
 class EmbeddingsWithFixes(torch.nn.Module):
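
tokenize() returns both the raw token count and a target count. As a point of reference (an assumption about sd_hijack_clip's helper, which is not shown in this diff), get_target_prompt_token_count rounds the count up to webui's 75-token chunk size:

import math

# Assumed behavior of sd_hijack_clip.get_target_prompt_token_count:
# round a prompt's token count up to a whole 75-token chunk.
def get_target_prompt_token_count(token_count: int) -> int:
    return math.ceil(max(token_count, 1) / 75) * 75

print(get_target_prompt_token_count(10))  # 75
print(get_target_prompt_token_count(80))  # 150
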