From 706d5944a075a6523ea7f00165d630efc085ca22 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 13:38:57 +0300 Subject: let user choose his own prompt token count limit --- modules/sd_hijack.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index d68f89cc..340329c0 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -18,7 +18,6 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward - def apply_optimizations(): undo_optimizations() @@ -83,7 +82,7 @@ class StableDiffusionModelHijack: layer.padding_mode = 'circular' if enable else 'zeros' def tokenize(self, text): - max_length = self.clip.max_length - 2 + max_length = opts.max_prompt_tokens - 2 _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, max_length @@ -94,7 +93,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.wrapped = wrapped self.hijack: StableDiffusionModelHijack = hijack self.tokenizer = wrapped.tokenizer - self.max_length = wrapped.max_length self.token_mults = {} tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] @@ -116,7 +114,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def tokenize_line(self, line, used_custom_terms, hijack_comments): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id - maxlen = self.wrapped.max_length + maxlen = opts.max_prompt_tokens if opts.enable_emphasis: parsed = prompt_parser.parse_prompt_attention(line) @@ -191,7 +189,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def process_text_old(self, text): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id - maxlen = self.wrapped.max_length + maxlen = self.wrapped.max_length # you get to stay at 77 used_custom_terms = [] remade_batch_tokens = [] overflowing_words = [] @@ -268,8 +266,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if len(used_custom_terms) > 0: self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + position_ids_array = [min(x, 75) for x in range(len(remade_batch_tokens[0])-1)] + [76] + position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1)) + tokens = torch.asarray(remade_batch_tokens).to(device) - outputs = self.wrapped.transformer(input_ids=tokens) + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) z = outputs.last_hidden_state # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise -- cgit v1.2.3