author     klimaleksus <klimaleksus@gmail.com>    2023-05-28 20:09:59 +0000
committer  GitHub <noreply@github.com>            2023-05-28 20:09:59 +0000
commit     4635f31270d1b5d41ad63815cb400b1ca73ea859 (patch)
tree       51ab0950b10ffce16e9b61e989522c120db3205b /modules/textual_inversion/textual_inversion.py
parent     b957dcfece29c84ac0cfcd5a69475ff8684c531f (diff)
Refactor EmbeddingDatabase.register_embedding() to allow unregistering
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r--    modules/textual_inversion/textual_inversion.py    25
1 file changed, 19 insertions, 6 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index d489ed1e..cbf94498 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -120,16 +120,29 @@ class EmbeddingDatabase:
         self.embedding_dirs.clear()
 
     def register_embedding(self, embedding, model):
-        self.word_embeddings[embedding.name] = embedding
-
-        ids = model.cond_stage_model.tokenize([embedding.name])[0]
+        return self.register_embedding_by_name(embedding, model, embedding.name)
 
+    def register_embedding_by_name(self, embedding, model, name):
+        ids = model.cond_stage_model.tokenize([name])[0]
         first_id = ids[0]
         if first_id not in self.ids_lookup:
             self.ids_lookup[first_id] = []
-
-        self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
-
+        if name in self.word_embeddings:
+            # remove old one from the lookup list
+            lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name]
+        else:
+            lookup = self.ids_lookup[first_id]
+        if embedding is not None:
+            lookup += [(ids, embedding)]
+        self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True)
+        if embedding is None:
+            # unregister embedding with specified name
+            if name in self.word_embeddings:
+                del self.word_embeddings[name]
+            if len(self.ids_lookup[first_id])==0:
+                del self.ids_lookup[first_id]
+            return None
+        self.word_embeddings[name] = embedding
         return embedding
 
     def get_expected_shape(self):
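For context, a minimal usage sketch of the new method, not part of this diff: the names sd_hijack.model_hijack.embedding_db (the shared EmbeddingDatabase), shared.sd_model (the loaded checkpoint), and the zero vector passed to Embedding are assumptions drawn from the surrounding webui code, used here only for illustration. Registering under an explicit name and passing embedding=None to unregister is the behaviour this refactor adds:

    import torch
    from modules import shared, sd_hijack
    from modules.textual_inversion.textual_inversion import Embedding

    db = sd_hijack.model_hijack.embedding_db  # EmbeddingDatabase instance (assumed location)

    # hypothetical single-token embedding vector, purely for illustration
    embedding = Embedding(torch.zeros(1, 768), "my-token")

    # register under an explicit name; register_embedding() is now just a
    # wrapper that forwards embedding.name to register_embedding_by_name()
    db.register_embedding_by_name(embedding, shared.sd_model, "my-token")

    # unregister: embedding=None removes the entry from word_embeddings and from
    # the ids_lookup bucket for its first token id, deleting the bucket if empty
    db.register_embedding_by_name(None, shared.sd_model, "my-token")

Keeping unregistration inside register_embedding_by_name(), rather than in a separate method, lets the same name-keyed maintenance of ids_lookup handle both replacing an existing embedding and removing it.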