aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_hijack.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--modules/sd_hijack.py38
1 files changed, 38 insertions, 0 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 6b5aae4b..f5615967 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -15,6 +15,11 @@ import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules
+import sgm.modules.attention
+import sgm.modules.diffusionmodules.model
+import sgm.modules.diffusionmodules.openaimodel
+import sgm.modules.encoders.modules
+
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
@@ -56,6 +61,9 @@ def apply_optimizations(option=None):
ldm.modules.diffusionmodules.model.nonlinearity = silu
ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+ sgm.modules.diffusionmodules.model.nonlinearity = silu
+ sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+
if current_optimizer is not None:
current_optimizer.undo()
current_optimizer = None
@@ -89,6 +97,10 @@ def undo_optimizations():
ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+ sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+ sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+ sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+
def fix_checkpoint():
"""checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want
@@ -168,6 +180,32 @@ class StableDiffusionModelHijack:
undo_optimizations()
def hijack(self, m):
+ conditioner = getattr(m, 'conditioner', None)
+ if conditioner:
+ text_cond_models = []
+
+ for i in range(len(conditioner.embedders)):
+ embedder = conditioner.embedders[i]
+ typename = type(embedder).__name__
+ if typename == 'FrozenOpenCLIPEmbedder':
+ embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+ conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
+ text_cond_models.append(conditioner.embedders[i])
+ if typename == 'FrozenCLIPEmbedder':
+ model_embeddings = embedder.transformer.text_model.embeddings
+ model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
+ conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
+ text_cond_models.append(conditioner.embedders[i])
+ if typename == 'FrozenOpenCLIPEmbedder2':
+ embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+ conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
+ text_cond_models.append(conditioner.embedders[i])
+
+ if len(text_cond_models) == 1:
+ m.cond_stage_model = text_cond_models[0]
+ else:
+ m.cond_stage_model = conditioner
+
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)