aboutsummaryrefslogtreecommitdiffstats
path: root/modules/textual_inversion/textual_inversion.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-10-02 16:40:51 +0000
committerAUTOMATIC <16777216c@gmail.com>2022-10-02 16:40:51 +0000
commit88ec0cf5571883d84abd09196652b3679e359f2e (patch)
tree030b33b3060c750ea5e8212049c293388f2fc3b3 /modules/textual_inversion/textual_inversion.py
parent53a3dc601fb734ce433505b1ca68770919106bad (diff)
downloadstable-diffusion-webui-gfx803-88ec0cf5571883d84abd09196652b3679e359f2e.tar.gz
stable-diffusion-webui-gfx803-88ec0cf5571883d84abd09196652b3679e359f2e.tar.bz2
stable-diffusion-webui-gfx803-88ec0cf5571883d84abd09196652b3679e359f2e.zip
fix for incorrect embedding token length calculation (will break seeds that use embeddings, you're welcome!)
add option to input initialization text for embeddings
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r--modules/textual_inversion/textual_inversion.py13
1 files changed, 5 insertions, 8 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index c0baaace..0c50161d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -117,24 +117,21 @@ class EmbeddingDatabase:
possible_matches = self.ids_lookup.get(token, None)
if possible_matches is None:
- return None
+ return None, None
for ids, embedding in possible_matches:
if tokens[offset:offset + len(ids)] == ids:
- return embedding
+ return embedding, len(ids)
- return None
+ return None, None
-
-def create_embedding(name, num_vectors_per_token):
- init_text = '*'
-
+def create_embedding(name, num_vectors_per_token, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
- embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
for i in range(num_vectors_per_token):