author     AUTOMATIC <16777216c@gmail.com>    2022-10-02 16:40:51 +0000
committer  AUTOMATIC <16777216c@gmail.com>    2022-10-02 16:40:51 +0000
commit     88ec0cf5571883d84abd09196652b3679e359f2e (patch)
tree       030b33b3060c750ea5e8212049c293388f2fc3b3 /modules/textual_inversion/textual_inversion.py
parent     53a3dc601fb734ce433505b1ca68770919106bad (diff)
fix for incorrect embedding token length calculation (will break seeds that use embeddings, you're welcome!)
add option to input initialization text for embeddings
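
A minimal sketch of why the second return value matters (not part of this commit; FakeEmbeddingDatabase and expand_tokens are made-up stand-ins): a caller walking the prompt tokens can now advance by the embedding's real token count instead of assuming one token per embedding, which is what the "incorrect embedding token length calculation" refers to.

# Hypothetical illustration only; not code from this repository.
class FakeEmbeddingDatabase:
    def __init__(self, embeddings):
        self.embeddings = embeddings  # dict: tuple of token ids -> embedding

    def find_embedding_at_position(self, tokens, offset):
        for ids, embedding in self.embeddings.items():
            if tuple(tokens[offset:offset + len(ids)]) == ids:
                return embedding, len(ids)  # new: also report the token length
        return None, None


def expand_tokens(tokens, db):
    result, i = [], 0
    while i < len(tokens):
        embedding, length = db.find_embedding_at_position(tokens, i)
        if embedding is None:
            result.append(tokens[i])
            i += 1
        else:
            result.append(embedding)
            i += length  # advance by the embedding's real token count
    return result


db = FakeEmbeddingDatabase({(101, 102, 103): "my-embedding"})
print(expand_tokens([7, 101, 102, 103, 9], db))  # [7, 'my-embedding', 9]
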
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r--  modules/textual_inversion/textual_inversion.py  13
1 file changed, 5 insertions, 8 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index c0baaace..0c50161d 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -117,24 +117,21 @@ class EmbeddingDatabase:
         possible_matches = self.ids_lookup.get(token, None)
 
         if possible_matches is None:
-            return None
+            return None, None
 
         for ids, embedding in possible_matches:
             if tokens[offset:offset + len(ids)] == ids:
-                return embedding
+                return embedding, len(ids)
 
-        return None
+        return None, None
 
 
-
-def create_embedding(name, num_vectors_per_token):
-    init_text = '*'
-
+def create_embedding(name, num_vectors_per_token, init_text='*'):
     cond_model = shared.sd_model.cond_stage_model
     embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
 
     ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
-    embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
+    embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
     vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
 
     for i in range(num_vectors_per_token):
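
Usage sketch for the new init_text parameter (hypothetical call; "my-style" and the init text are example values, and a running webui with shared.sd_model loaded is assumed). Per the commit message, the new vectors can now be seeded from user-supplied text rather than the hard-coded '*'; the part of create_embedding below this hunk is not shown here.

# Hypothetical usage of the signature shown in the diff above.
from modules.textual_inversion import textual_inversion

textual_inversion.create_embedding(
    "my-style",                # name of the new embedding
    num_vectors_per_token=4,   # number of vectors to learn
    init_text="oil painting",  # text whose token embeddings seed the vectors
)
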