aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--modules/sd_hijack.py8
-rw-r--r--modules/sd_hijack_clip.py2
-rw-r--r--modules/sd_models_xl.py9
-rw-r--r--modules/textual_inversion/textual_inversion.py19
4 files changed, 29 insertions, 9 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index c8fdd4f1..cfa5f0eb 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -197,7 +197,7 @@ class StableDiffusionModelHijack:
conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
text_cond_models.append(conditioner.embedders[i])
if typename == 'FrozenOpenCLIPEmbedder2':
- embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+ embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
text_cond_models.append(conditioner.embedders[i])
@@ -292,10 +292,11 @@ class StableDiffusionModelHijack:
class EmbeddingsWithFixes(torch.nn.Module):
- def __init__(self, wrapped, embeddings):
+ def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
super().__init__()
self.wrapped = wrapped
self.embeddings = embeddings
+ self.textual_inversion_key = textual_inversion_key
def forward(self, input_ids):
batch_fixes = self.embeddings.fixes
@@ -309,7 +310,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
vecs = []
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
- emb = devices.cond_cast_unet(embedding.vec)
+ vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
+ emb = devices.cond_cast_unet(vec)
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 16a5500e..2f9d569b 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -161,7 +161,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
position += 1
continue
- emb_len = int(embedding.vec.shape[0])
+ emb_len = int(embedding.vectors)
if len(chunk.tokens) + emb_len > self.chunk_length:
next_chunk()
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
index 40559208..bc219508 100644
--- a/modules/sd_models_xl.py
+++ b/modules/sd_models_xl.py
@@ -56,6 +56,14 @@ def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text,
return torch.cat(res, dim=1)
+def tokenize(self: sgm.modules.GeneralConditioner, texts):
+ for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
+ return embedder.tokenize(texts)
+
+ raise AssertionError('no tokenizer available')
+
+
+
def process_texts(self, texts):
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
return embedder.process_texts(texts)
@@ -68,6 +76,7 @@ def get_target_prompt_token_count(self, token_count):
# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
+sgm.modules.GeneralConditioner.tokenize = tokenize
sgm.modules.GeneralConditioner.process_texts = process_texts
sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 6166c76f..4713bc2d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -181,29 +181,38 @@ class EmbeddingDatabase:
else:
return
+
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
- # diffuser concepts
- elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ shape = vec.shape[-1]
+ vectors = vec.shape[0]
+ elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
+ vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+ shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+ vectors = data['clip_g'].shape[0]
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ shape = vec.shape[-1]
+ vectors = vec.shape[0]
else:
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
- vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name)
embedding.step = data.get('step', None)
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
- embedding.vectors = vec.shape[0]
- embedding.shape = vec.shape[-1]
+ embedding.vectors = vectors
+ embedding.shape = shape
embedding.filename = path
embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')