author | Shondoit <shondoit@gmail.com> | 2023-01-12 08:22:29 +0000 |
---|---|---|
committer | Shondoit <shondoit@gmail.com> | 2023-01-12 08:22:29 +0000 |
commit | d52a80f7f7da160c73afd067c8f1bf491391f994 (patch) | |
tree | 6707662f0f311d5e5346d436b543333e66021af6 /modules/textual_inversion | |
parent | 0b8911d883118daa54f7735c5b753b5575d9f943 (diff) | |
Allow creation of zero vectors for TI
Diffstat (limited to 'modules/textual_inversion')
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index b915b091..853246a6 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -248,11 +248,14 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
with devices.autocast():
cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
- embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
+ #cond_model expects at least some text, so we provide '*' as backup.
+ embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
- for i in range(num_vectors_per_token):
- vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+ #Only copy if we provided an init_text, otherwise keep vectors as zeros
+ if init_text:
+ for i in range(num_vectors_per_token):
+ vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))