diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-05-31 16:29:47 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-31 16:29:47 +0000 |
commit | 881de0df38c1fa6d0d61f7bc6fc93c100a9f35d0 (patch) | |
tree | 7079024b0ffc02744caf1083a46268bb3cc941d2 /modules/textual_inversion | |
parent | 670195d7202d49bec0c3d489cd7bbaac9dcd5901 (diff) | |
parent | 4635f31270d1b5d41ad63815cb400b1ca73ea859 (diff) | |
download | stable-diffusion-webui-gfx803-881de0df38c1fa6d0d61f7bc6fc93c100a9f35d0.tar.gz stable-diffusion-webui-gfx803-881de0df38c1fa6d0d61f7bc6fc93c100a9f35d0.tar.bz2 stable-diffusion-webui-gfx803-881de0df38c1fa6d0d61f7bc6fc93c100a9f35d0.zip |
Merge pull request #10803 from klimaleksus/refactoring-for-embedding-merge
Refactor EmbeddingDatabase.register_embedding() to allow unregistering
Diffstat (limited to 'modules/textual_inversion')
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 25 |
1 files changed, 19 insertions, 6 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index a040a988..b3dcb140 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -119,16 +119,29 @@ class EmbeddingDatabase: self.embedding_dirs.clear()
def register_embedding(self, embedding, model):
- self.word_embeddings[embedding.name] = embedding
-
- ids = model.cond_stage_model.tokenize([embedding.name])[0]
+ return self.register_embedding_by_name(embedding, model, embedding.name)
+ def register_embedding_by_name(self, embedding, model, name):
+ ids = model.cond_stage_model.tokenize([name])[0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
-
- self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
-
+ if name in self.word_embeddings:
+ # remove old one from the lookup list
+ lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
+ else:
+ lookup = self.ids_lookup[first_id]
+ if embedding is not None:
+ lookup += [(ids, embedding)]
+ self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
+ if embedding is None:
+ # unregister embedding with specified name
+ if name in self.word_embeddings:
+ del self.word_embeddings[name]
+ if len(self.ids_lookup[first_id])==0:
+ del self.ids_lookup[first_id]
+ return None
+ self.word_embeddings[name] = embedding
return embedding
def get_expected_shape(self):
|