Diffstat (limited to 'modules/sd_hijack.py')
-rw-r--r--  modules/sd_hijack.py | 23 +++++++++++++++++++++--
1 file changed, 21 insertions(+), 2 deletions(-)
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 827bf304..f07ec041 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -8,7 +8,7 @@ from torch import einsum
 from torch.nn.functional import silu
 
 import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared
 from modules.shared import opts, device, cmd_opts
 
 import ldm.modules.attention
@@ -37,6 +37,8 @@ def apply_optimizations():
 
 
 def undo_optimizations():
+    from modules.hypernetworks import hypernetwork
+
     ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
     ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
     ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
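A note on the import change in this hunk: the module-scope import of hypernetwork (removed in the first hunk) becomes a function-local import inside undo_optimizations, which defers resolution to call time, the usual way to break a circular import between two modules. A minimal runnable sketch of the pattern, using a stdlib module as a hypothetical stand-in:

# deferred_import_sketch.py -- the pattern used in undo_optimizations above.
# json stands in for modules.hypernetworks.hypernetwork: at module scope the
# import could run while the other module is still half-initialised, but
# inside the function body it resolves only when the function is called.

def undo_optimizations_sketch():
    import json  # deferred: resolved at call time, not at module load
    return json.dumps({"optimizations": "restored"})

print(undo_optimizations_sketch())  # {"optimizations": "restored"}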
@@ -107,6 +109,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         self.tokenizer = wrapped.tokenizer
         self.token_mults = {}
 
+        self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
+
         tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
         for text, ident in tokens_with_parens:
             mult = 1.0
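The new comma_token line scans CLIP's BPE vocabulary for ',</w>', where the '</w>' suffix marks a token that ends a word, and keeps its integer id for the chunk-boundary check later in the diff. A sketch of the same lookup against the HuggingFace CLIPTokenizer (the model name is an assumption here, not something this diff specifies):

# comma_token_sketch.py -- finding the end-of-word comma id in CLIP's vocab.
# Assumes the transformers CLIPTokenizer; "openai/clip-vit-large-patch14" is
# the tokenizer commonly paired with Stable Diffusion v1 models, but that is
# an assumption, not stated in this diff.
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
comma_token = [v for k, v in tokenizer.get_vocab().items() if k == ',</w>'][0]
print(comma_token)  # the id later compared against each prompt token

A plain dict index, tokenizer.get_vocab()[',</w>'], would return the same id without scanning the whole vocabulary.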
@@ -136,6 +140,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             fixes = []
             remade_tokens = []
             multipliers = []
+            last_comma = -1
 
             for tokens, (text, weight) in zip(tokenized, parsed):
                 i = 0
@@ -144,6 +149,20 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
 
                     embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
 
+                    if token == self.comma_token:
+                        last_comma = len(remade_tokens)
+                    elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack:
+                        last_comma += 1
+                        reloc_tokens = remade_tokens[last_comma:]
+                        reloc_mults = multipliers[last_comma:]
+
+                        remade_tokens = remade_tokens[:last_comma]
+                        length = len(remade_tokens)
+
+                        rem = int(math.ceil(length / 75)) * 75 - length
+                        remade_tokens += [id_end] * rem + reloc_tokens
+                        multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
+
                     if embedding is None:
                         remade_tokens.append(token)
                         multipliers.append(weight)
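The added branch is the heart of the change: prompts are fed to CLIP in 75-token chunks, and when a chunk boundary lands within opts.comma_padding_backtrack tokens after the last comma, everything past that comma is relocated to the start of the next chunk, with the current chunk padded out using id_end. A standalone sketch of the relocation step (ID_END and CHUNK are stand-ins for id_end and the 75-token window, not the module's actual API; the original only runs this when the token count hits a multiple of 75 and the comma is within the backtrack limit):

# comma_backtrack_sketch.py -- standalone rework of the relocation logic.
import math

ID_END = 49407  # CLIP's end-of-text id, used here as chunk padding
CHUNK = 75

def backtrack_at_boundary(remade_tokens, multipliers, last_comma):
    """Move everything after the last comma into the next 75-token chunk."""
    last_comma += 1  # keep the comma itself in the current chunk
    reloc_tokens = remade_tokens[last_comma:]
    reloc_mults = multipliers[last_comma:]

    remade_tokens = remade_tokens[:last_comma]
    length = len(remade_tokens)

    rem = int(math.ceil(length / CHUNK)) * CHUNK - length  # padding to fill the chunk
    remade_tokens += [ID_END] * rem + reloc_tokens
    multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
    return remade_tokens, multipliers

# usage: a comma at index 71 of a 75-token stream pulls the last 3 tokens forward
tokens, mults = list(range(75)), [1.0] * 75
tokens2, mults2 = backtrack_at_boundary(tokens, mults, last_comma=71)
print(len(tokens2))  # 78: 72 kept + 3 ID_END pads, then 3 tokens opening chunk two

The effect is that a phrase interrupted by the chunk boundary travels whole into the next chunk instead of being split mid-phrase.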
@@ -284,7 +303,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         while max(map(len, remade_batch_tokens)) != 0:
             rem_tokens = [x[75:] for x in remade_batch_tokens]
             rem_multipliers = [x[75:] for x in batch_multipliers]
-            
+
             self.hijack.fixes = []
             for unfiltered in hijack_fixes:
                 fixes = []
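This final hunk only strips trailing whitespace, but its context shows how the remade tokens are consumed downstream: each pass of the loop peels one 75-token slice off every prompt in the batch until all streams are empty. A minimal sketch of that slicing pattern, with hypothetical data and names:

# chunk_consumption_sketch.py -- the slicing pattern in the hunk above.
# Each pass takes one 75-token chunk per prompt and keeps the remainder
# (plain ints here; no embedding-fix handling).
CHUNK = 75
remade_batch_tokens = [list(range(160)), list(range(40))]

while max(map(len, remade_batch_tokens)) != 0:
    batch_chunk = [x[:CHUNK] for x in remade_batch_tokens]          # current chunk per prompt
    remade_batch_tokens = [x[CHUNK:] for x in remade_batch_tokens]  # remainder for next pass
    print([len(c) for c in batch_chunk])
# prints [75, 40], then [75, 0], then [10, 0]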