author    AUTOMATIC <16777216c@gmail.com>    2023-01-15 19:29:53 +0000
committer AUTOMATIC <16777216c@gmail.com>    2023-01-15 19:29:53 +0000
commit    8e2aeee4a127b295bfc880800e4a312e0f049b85 (patch)
tree      a1354cb94c09ad00216e5eadda48259be494fea9 /modules/sd_hijack_clip.py
parent    205991df7826429e6183fc4afbbda3d321c9fee4 (diff)
add BREAK keyword to end current text chunk and start the next
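In a prompt, an uppercase BREAK written as its own word ends the chunk being built: the check added below fires when the parser hands tokenize_line a 'BREAK' section with weight -1, the chunk in progress is padded to chunk_length (75 tokens) with <end-of-text> tokens, and the text after BREAK starts in a fresh chunk. An illustrative prompt such as `a lighthouse in a storm BREAK oil painting, high detail` therefore produces two chunks instead of one.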
Diffstat (limited to 'modules/sd_hijack_clip.py')
 modules/sd_hijack_clip.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index 852afc66..9fa5c5c5 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -96,13 +96,18 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         token_count = 0
         last_comma = -1
 
-        def next_chunk():
-            """puts current chunk into the list of results and produces the next one - empty"""
+        def next_chunk(is_last=False):
+            """puts current chunk into the list of results and produces the next one - empty;
+            if is_last is true, <end-of-text> tokens at the end won't add to token_count"""
             nonlocal token_count
             nonlocal last_comma
             nonlocal chunk
 
-            token_count += len(chunk.tokens)
+            if is_last:
+                token_count += len(chunk.tokens)
+            else:
+                token_count += self.chunk_length
+
             to_add = self.chunk_length - len(chunk.tokens)
             if to_add > 0:
                 chunk.tokens += [self.id_end] * to_add
@@ -116,6 +121,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
             chunk = PromptChunk()
 
         for tokens, (text, weight) in zip(tokenized, parsed):
+            if text == 'BREAK' and weight == -1:
+                next_chunk()
+                continue
+
             position = 0
             while position < len(tokens):
                 token = tokens[position]
@@ -159,7 +168,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
                 position += embedding_length_in_tokens
 
         if len(chunk.tokens) > 0 or len(chunks) == 0:
-            next_chunk()
+            next_chunk(is_last=True)
 
         return chunks, token_count
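Below is a minimal, self-contained Python sketch of the chunking behaviour this commit produces. It is an illustration only, not the webui implementation: the real logic in FrozenCLIPEmbedderWithCustomWordsBase.tokenize_line also tracks per-token weights, backtracks to the last comma when a chunk overflows, and resolves textual-inversion embeddings. The names split_into_chunks, CHUNK_LENGTH and ID_END are invented here (they stand in for self.chunk_length and self.id_end), and the token ids in the example are dummies.

# Minimal sketch of the chunking behaviour after this commit -- an
# illustration, not the webui code.

CHUNK_LENGTH = 75   # tokens per CLIP chunk
ID_END = 49407      # CLIP <end-of-text> token id

def split_into_chunks(sections):
    """sections: list of (tokens, text) pairs, one per parsed prompt section."""
    chunks = []
    chunk = []
    token_count = 0

    def next_chunk(is_last=False):
        nonlocal chunk, token_count
        # a flushed chunk costs a full CHUNK_LENGTH tokens; only the final,
        # partially filled chunk is counted by its real length
        token_count += len(chunk) if is_last else CHUNK_LENGTH
        chunk += [ID_END] * (CHUNK_LENGTH - len(chunk))  # pad with <end-of-text>
        chunks.append(chunk)
        chunk = []

    for tokens, text in sections:
        if text == 'BREAK':
            next_chunk()       # BREAK flushes the chunk in progress
            continue
        for token in tokens:
            if len(chunk) == CHUNK_LENGTH:
                next_chunk()
            chunk.append(token)

    if chunk or not chunks:
        next_chunk(is_last=True)

    return chunks, token_count

# two BREAK-separated sections become two chunks; the first costs a full
# 75 tokens, the second only its two real tokens:
chunks, count = split_into_chunks([([101, 102], 'a cat'), ([], 'BREAK'), ([103, 104], 'a dog')])
assert len(chunks) == 2 and count == 77

The sketch mirrors the commit's accounting change: a chunk flushed mid-prompt (by BREAK or by overflow) is billed at the full chunk length, while only the last chunk is billed by its actual token count.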