Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r-- | modules/sd_hijack.py | 77 |
1 files changed, 61 insertions, 16 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index c8fdd4f1..e139d996 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -2,15 +2,15 @@ import torch
from torch.nn.functional import silu
from types import MethodType
-import modules.textual_inversion.textual_inversion
-from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
+from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
+import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules
@@ -30,12 +30,20 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
# silence new console spam from SD2
-ldm.modules.attention.print = lambda *args: None
-ldm.modules.diffusionmodules.model.print = lambda *args: None
+ldm.modules.attention.print = shared.ldm_print
+ldm.modules.diffusionmodules.model.print = shared.ldm_print
+ldm.util.print = shared.ldm_print
+ldm.models.diffusion.ddpm.print = shared.ldm_print
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
+ldm_patched_forward = sd_unet.create_unet_forward(ldm.modules.diffusionmodules.openaimodel.UNetModel.forward)
+ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", ldm_patched_forward)
+
+sgm_patched_forward = sd_unet.create_unet_forward(sgm.modules.diffusionmodules.openaimodel.UNetModel.forward)
+sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sgm_patched_forward)
+
def list_optimizers():
new_optimizers = script_callbacks.list_optimizers_callback()
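Note: the two patch calls above replace the earlier hand-rolled copy_of_UNetModel_forward_for_webui backup (removed further down) with a patch-and-keep-original approach. A minimal sketch of that pattern, using hypothetical stand-in names (patch, create_unet_forward, UNetModelStub) rather than the real helpers in modules/patches.py and modules/sd_unet.py:

# Sketch of the patch-and-keep-original pattern (illustrative; names are stand-ins).
_originals = {}

def patch(key, obj, field, replacement):
    # Remember the original attribute so it can be called or restored later,
    # then install the replacement in its place and hand the original back.
    original = getattr(obj, field)
    _originals[(key, id(obj), field)] = original
    setattr(obj, field, replacement)
    return original

class UNetModelStub:
    def forward(self, x):
        return x * 2

def create_unet_forward(original):
    def wrapped(self, x):
        # A custom U-Net implementation could be consulted here; otherwise
        # fall back to the original forward captured at patch time.
        return original(self, x)
    return wrapped

original_forward = patch("sd_hijack_sketch", UNetModelStub, "forward",
                         create_unet_forward(UNetModelStub.forward))
assert UNetModelStub().forward(3) == 6  # calls now go through the wrapper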
@@ -164,12 +172,13 @@ class StableDiffusionModelHijack:
clip = None
optimization_method = None
- embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
-
def __init__(self):
+ import modules.textual_inversion.textual_inversion
+
self.extra_generation_params = {}
self.comments = []
+ self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
def apply_optimizations(self, option=None):
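The embedding database also moves from a class attribute to an instance attribute built in __init__, with the textual_inversion import deferred until construction. A small illustration of the difference, with hypothetical stand-in classes (EmbeddingDatabaseStub, HijackSketch):

# Illustrative only: per-instance construction in __init__ versus a shared
# class attribute created once at import time.
class EmbeddingDatabaseStub:
    def __init__(self):
        self.embeddings = {}

class HijackSketch:
    def __init__(self):
        # Built lazily, once per instance, instead of once for the whole class;
        # deferring the import here also avoids pulling the module in at load time.
        self.embedding_db = EmbeddingDatabaseStub()

a, b = HijackSketch(), HijackSketch()
assert a.embedding_db is not b.embedding_db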
@@ -179,6 +188,20 @@ class StableDiffusionModelHijack:
errors.display(e, "applying cross attention optimization")
undo_optimizations()
+ def convert_sdxl_to_ssd(self, m):
+ """Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)"""
+
+ delattr(m.model.diffusion_model.middle_block, '1')
+ delattr(m.model.diffusion_model.middle_block, '2')
+ for i in ['9', '8', '7', '6', '5', '4']:
+ delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i)
+ delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1')
+ delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1')
+ devices.torch_gc()
+
def hijack(self, m):
conditioner = getattr(m, 'conditioner', None)
if conditioner:
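convert_sdxl_to_ssd prunes the SDXL U-Net down to the SSD-1B layout by deleting whole submodules, then frees cached memory with devices.torch_gc(). The delattr calls with string indices work because torch.nn.ModuleList and nn.Sequential register their children under string keys; a self-contained sketch of that mechanism (the ten-block container is a stand-in, not the real SDXL architecture):

import torch.nn as nn

# Stand-in container loosely mirroring a transformer_blocks list (illustrative only).
blocks = nn.ModuleList(nn.Linear(4, 4) for _ in range(10))

# nn.ModuleList stores children in _modules under string keys '0'..'9', so
# delattr with a string index removes the submodule (and its parameters).
for i in ['9', '8', '7', '6', '5', '4']:
    delattr(blocks, i)

print(len(blocks))  # 4 blocks remain; a gc / torch_gc pass can then reclaim the memory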
@@ -197,7 +220,7 @@ class StableDiffusionModelHijack:
conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
text_cond_models.append(conditioner.embedders[i])
if typename == 'FrozenOpenCLIPEmbedder2':
- embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+ embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
text_cond_models.append(conditioner.embedders[i])
@@ -206,7 +229,7 @@ class StableDiffusionModelHijack:
else:
m.cond_stage_model = conditioner
- if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
+ if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation or type(m.cond_stage_model) == xlmr_m18.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)
@@ -237,13 +260,34 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
- if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
- ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
+ import modules.models.diffusion.ddpm_edit
+
+ if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
+ sd_unet.original_forward = ldm_original_forward
+ elif isinstance(m, modules.models.diffusion.ddpm_edit.LatentDiffusion):
+ sd_unet.original_forward = ldm_original_forward
+ elif isinstance(m, sgm.models.diffusion.DiffusionEngine):
+ sd_unet.original_forward = sgm_original_forward
+ else:
+ sd_unet.original_forward = None
- ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
def undo_hijack(self, m):
- if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
+ conditioner = getattr(m, 'conditioner', None)
+ if conditioner:
+ for i in range(len(conditioner.embedders)):
+ embedder = conditioner.embedders[i]
+ if isinstance(embedder, (sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords, sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords)):
+ embedder.wrapped.model.token_embedding = embedder.wrapped.model.token_embedding.wrapped
+ conditioner.embedders[i] = embedder.wrapped
+ if isinstance(embedder, sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords):
+ embedder.wrapped.transformer.text_model.embeddings.token_embedding = embedder.wrapped.transformer.text_model.embeddings.token_embedding.wrapped
+ conditioner.embedders[i] = embedder.wrapped
+
+ if hasattr(m, 'cond_stage_model'):
+ delattr(m, 'cond_stage_model')
+
+ elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
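undo_hijack now mirrors the SDXL path in hijack: each custom-words wrapper keeps the module it replaced in .wrapped, so unhooking just reinstalls the originals from those references. A compact sketch of that wrap/unwrap roundtrip, with a hypothetical EmbeddingsWithFixesSketch class standing in for the real wrapper:

import torch.nn as nn

class EmbeddingsWithFixesSketch(nn.Module):
    # Wraps an embedding layer and remembers the original in .wrapped.
    def __init__(self, wrapped):
        super().__init__()
        self.wrapped = wrapped

    def forward(self, input_ids):
        # The real wrapper splices textual-inversion vectors in here;
        # this sketch just delegates to the wrapped embedding.
        return self.wrapped(input_ids)

token_embedding = nn.Embedding(16, 8)

# hijack: install the wrapper in place of the original embedding
hijacked = EmbeddingsWithFixesSketch(token_embedding)

# undo_hijack: restore the original by reaching through .wrapped
restored = hijacked.wrapped
assert restored is token_embedding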
@@ -263,7 +307,6 @@ class StableDiffusionModelHijack:
self.layers = None
self.clip = None
- ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
def apply_circular(self, enable):
if self.circular_enabled == enable:
@@ -292,10 +335,11 @@ class StableDiffusionModelHijack:
class EmbeddingsWithFixes(torch.nn.Module):
- def __init__(self, wrapped, embeddings):
+ def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
super().__init__()
self.wrapped = wrapped
self.embeddings = embeddings
+ self.textual_inversion_key = textual_inversion_key
def forward(self, input_ids):
batch_fixes = self.embeddings.fixes
@@ -309,7 +353,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
vecs = []
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
- emb = devices.cond_cast_unet(embedding.vec)
+ vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
+ emb = devices.cond_cast_unet(vec)
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
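With SDXL, a single textual-inversion embedding can carry one vector per text encoder, so embedding.vec may now be a dict keyed by 'clip_l'/'clip_g'; EmbeddingsWithFixes picks the tensor for its own encoder before splicing it over the placeholder tokens. A standalone sketch of that selection and splice, using made-up shapes and assuming the offset points at the token just before the placeholder run:

import torch

textual_inversion_key = 'clip_g'  # 'clip_l' for the CLIP-L encoder

# An SDXL-style embedding that ships one vector per text encoder (toy values).
embedding_vec = {'clip_l': torch.ones(2, 8), 'clip_g': torch.full((2, 8), 2.0)}

# Pick the tensor matching this encoder; plain (non-dict) embeddings are used as-is.
vec = embedding_vec[textual_inversion_key] if isinstance(embedding_vec, dict) else embedding_vec

tensor = torch.zeros(7, 8)   # token embeddings for one prompt
offset = 2                   # placeholder tokens start right after this position
emb_len = min(tensor.shape[0] - offset - 1, vec.shape[0])

# Splice the embedding vectors over the placeholder positions, keeping length fixed.
tensor = torch.cat([tensor[0:offset + 1], vec[0:emb_len], tensor[offset + 1 + emb_len:]])
print(tensor.shape)  # torch.Size([7, 8]); rows 3..4 now hold the embedding vectors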