From ce6911158b5b2f9cf79b405a1f368f875492044d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 26 Nov 2022 16:10:46 +0300 Subject: Add support Stable Diffusion 2.0 --- modules/sd_hijack_clip.py | 301 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 301 insertions(+) create mode 100644 modules/sd_hijack_clip.py (limited to 'modules/sd_hijack_clip.py') diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py new file mode 100644 index 00000000..b451d1cf --- /dev/null +++ b/modules/sd_hijack_clip.py @@ -0,0 +1,301 @@ +import math + +import torch + +from modules import prompt_parser, devices +from modules.shared import opts + + +def get_target_prompt_token_count(token_count): + return math.ceil(max(token_count, 1) / 75) * 75 + + +class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): + def __init__(self, wrapped, hijack): + super().__init__() + self.wrapped = wrapped + self.hijack = hijack + + def tokenize(self, texts): + raise NotImplementedError + + def encode_with_transformers(self, tokens): + raise NotImplementedError + + def encode_embedding_init_text(self, init_text, nvpt): + raise NotImplementedError + + def tokenize_line(self, line, used_custom_terms, hijack_comments): + if opts.enable_emphasis: + parsed = prompt_parser.parse_prompt_attention(line) + else: + parsed = [[line, 1.0]] + + tokenized = self.tokenize([text for text, _ in parsed]) + + fixes = [] + remade_tokens = [] + multipliers = [] + last_comma = -1 + + for tokens, (text, weight) in zip(tokenized, parsed): + i = 0 + while i < len(tokens): + token = tokens[i] + + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + + if token == self.comma_token: + last_comma = len(remade_tokens) + elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack: + last_comma += 1 + reloc_tokens = remade_tokens[last_comma:] + reloc_mults = multipliers[last_comma:] + + remade_tokens = remade_tokens[:last_comma] + length = len(remade_tokens) + + rem = int(math.ceil(length / 75)) * 75 - length + remade_tokens += [self.id_end] * rem + reloc_tokens + multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults + + if embedding is None: + remade_tokens.append(token) + multipliers.append(weight) + i += 1 + else: + emb_len = int(embedding.vec.shape[0]) + iteration = len(remade_tokens) // 75 + if (len(remade_tokens) + emb_len) // 75 != iteration: + rem = (75 * (iteration + 1) - len(remade_tokens)) + remade_tokens += [self.id_end] * rem + multipliers += [1.0] * rem + iteration += 1 + fixes.append((iteration, (len(remade_tokens) % 75, embedding))) + remade_tokens += [0] * emb_len + multipliers += [weight] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += embedding_length_in_tokens + + token_count = len(remade_tokens) + prompt_target_length = get_target_prompt_token_count(token_count) + tokens_to_add = prompt_target_length - len(remade_tokens) + + remade_tokens = remade_tokens + [self.id_end] * tokens_to_add + multipliers = multipliers + [1.0] * tokens_to_add + + return remade_tokens, fixes, multipliers, token_count + + def process_text(self, texts): + used_custom_terms = [] + remade_batch_tokens = [] + hijack_comments = [] + hijack_fixes = [] + token_count = 0 + + cache = {} + batch_multipliers = [] + for line in texts: + if line in cache: + remade_tokens, fixes, multipliers = cache[line] + else: + 
remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) + token_count = max(current_token_count, token_count) + + cache[line] = (remade_tokens, fixes, multipliers) + + remade_batch_tokens.append(remade_tokens) + hijack_fixes.append(fixes) + batch_multipliers.append(multipliers) + + return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + + def process_text_old(self, texts): + id_start = self.id_start + id_end = self.id_end + maxlen = self.wrapped.max_length # you get to stay at 77 + used_custom_terms = [] + remade_batch_tokens = [] + hijack_comments = [] + hijack_fixes = [] + token_count = 0 + + cache = {} + batch_tokens = self.tokenize(texts) + batch_multipliers = [] + for tokens in batch_tokens: + tuple_tokens = tuple(tokens) + + if tuple_tokens in cache: + remade_tokens, fixes, multipliers = cache[tuple_tokens] + else: + fixes = [] + remade_tokens = [] + multipliers = [] + mult = 1.0 + + i = 0 + while i < len(tokens): + token = tokens[i] + + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + + mult_change = self.token_mults.get(token) if opts.enable_emphasis else None + if mult_change is not None: + mult *= mult_change + i += 1 + elif embedding is None: + remade_tokens.append(token) + multipliers.append(mult) + i += 1 + else: + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [mult] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += embedding_length_in_tokens + + if len(remade_tokens) > maxlen - 2: + vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} + ovf = remade_tokens[maxlen - 2:] + overflowing_words = [vocab.get(int(x), "") for x in ovf] + overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) + hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + + token_count = len(remade_tokens) + remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) + remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] + cache[tuple_tokens] = (remade_tokens, fixes, multipliers) + + multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) + multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] + + remade_batch_tokens.append(remade_tokens) + hijack_fixes.append(fixes) + batch_multipliers.append(multipliers) + return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + + def forward(self, text): + use_old = opts.use_old_emphasis_implementation + if use_old: + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) + else: + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) + + self.hijack.comments += hijack_comments + + if len(used_custom_terms) > 0: + self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + + if use_old: + self.hijack.fixes = hijack_fixes + return self.process_tokens(remade_batch_tokens, batch_multipliers) + + z = None + i = 0 + while max(map(len, remade_batch_tokens)) != 0: + rem_tokens = [x[75:] for x in remade_batch_tokens] + rem_multipliers = 
[x[75:] for x in batch_multipliers] + + self.hijack.fixes = [] + for unfiltered in hijack_fixes: + fixes = [] + for fix in unfiltered: + if fix[0] == i: + fixes.append(fix[1]) + self.hijack.fixes.append(fixes) + + tokens = [] + multipliers = [] + for j in range(len(remade_batch_tokens)): + if len(remade_batch_tokens[j]) > 0: + tokens.append(remade_batch_tokens[j][:75]) + multipliers.append(batch_multipliers[j][:75]) + else: + tokens.append([self.id_end] * 75) + multipliers.append([1.0] * 75) + + z1 = self.process_tokens(tokens, multipliers) + z = z1 if z is None else torch.cat((z, z1), axis=-2) + + remade_batch_tokens = rem_tokens + batch_multipliers = rem_multipliers + i += 1 + + return z + + def process_tokens(self, remade_batch_tokens, batch_multipliers): + if not opts.use_old_emphasis_implementation: + remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens] + batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] + + tokens = torch.asarray(remade_batch_tokens).to(devices.device) + + if self.id_end != self.id_pad: + for batch_pos in range(len(remade_batch_tokens)): + index = remade_batch_tokens[batch_pos].index(self.id_end) + tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad + + z = self.encode_with_transformers(tokens) + + # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise + batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers] + batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device) + original_mean = z.mean() + z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) + new_mean = z.mean() + z *= original_mean / new_mean + + return z + + +class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + self.tokenizer = wrapped.tokenizer + self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] + + self.token_mults = {} + tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] + for text, ident in tokens_with_parens: + mult = 1.0 + for c in text: + if c == '[': + mult /= 1.1 + if c == ']': + mult *= 1.1 + if c == '(': + mult *= 1.1 + if c == ')': + mult /= 1.1 + + if mult != 1.0: + self.token_mults[ident] = mult + + self.id_start = self.wrapped.tokenizer.bos_token_id + self.id_end = self.wrapped.tokenizer.eos_token_id + self.id_pad = self.id_end + + def tokenize(self, texts): + tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"] + + return tokenized + + def encode_with_transformers(self, tokens): + outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) + + if opts.CLIP_stop_at_last_layers > 1: + z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] + z = self.wrapped.transformer.text_model.final_layer_norm(z) + else: + z = outputs.last_hidden_state + + return z + + def encode_embedding_init_text(self, init_text, nvpt): + embedding_layer = self.wrapped.transformer.text_model.embeddings + ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"] + embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) + + return embedded -- cgit v1.2.3 From 52cc83d36b7663a77b79fd2258d2ca871af73e55 Mon Sep 17 
00:00:00 2001 From: zhaohu xing <920232796@qq.com> Date: Wed, 30 Nov 2022 14:56:12 +0800 Subject: fix bugs Signed-off-by: zhaohu xing <920232796@qq.com> --- modules/sd_hijack_clip.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'modules/sd_hijack_clip.py') diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index b451d1cf..9ea6e1ce 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -4,7 +4,7 @@ import torch from modules import prompt_parser, devices from modules.shared import opts - +import modules.shared as shared def get_target_prompt_token_count(token_count): return math.ceil(max(token_count, 1) / 75) * 75 @@ -177,6 +177,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count def forward(self, text): + if shared.text_model_name == "XLMR-Large": + return self.wrapped.encode(text) + use_old = opts.use_old_emphasis_implementation if use_old: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) @@ -254,7 +257,10 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): def __init__(self, wrapped, hijack): super().__init__(wrapped, hijack) self.tokenizer = wrapped.tokenizer - self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] + if shared.text_model_name == "XLMR-Large": + self.comma_token = None + else : + self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] self.token_mults = {} tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] -- cgit v1.2.3 From f34c7341720fb2059992926c9f9ae6ff25f7385b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 31 Dec 2022 18:06:35 +0300 Subject: alt-diffusion integration --- modules/sd_hijack_clip.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) (limited to 'modules/sd_hijack_clip.py') diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 9ea6e1ce..6ec50cca 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -4,7 +4,6 @@ import torch from modules import prompt_parser, devices from modules.shared import opts -import modules.shared as shared def get_target_prompt_token_count(token_count): return math.ceil(max(token_count, 1) / 75) * 75 @@ -177,9 +176,6 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count def forward(self, text): - if shared.text_model_name == "XLMR-Large": - return self.wrapped.encode(text) - use_old = opts.use_old_emphasis_implementation if use_old: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) @@ -257,13 +253,13 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): def __init__(self, wrapped, hijack): super().__init__(wrapped, hijack) self.tokenizer = wrapped.tokenizer - if shared.text_model_name == "XLMR-Large": - self.comma_token = None - else : - self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] + + vocab = self.tokenizer.get_vocab() + + self.comma_token = vocab.get(',', None) self.token_mults = {} - tokens_with_parens = [(k, v) for k, v in 
self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] + tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k] for text, ident in tokens_with_parens: mult = 1.0 for c in text: -- cgit v1.2.3 From 210449b374d522c94a67fe54289a9eb515933a9f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 1 Jan 2023 02:41:15 +0300 Subject: fix 'RuntimeError: Expected all tensors to be on the same device' error preventing models from loading on lowvram/medvram. --- modules/sd_hijack_clip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack_clip.py') diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 6ec50cca..ca92b142 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -298,6 +298,6 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): def encode_embedding_init_text(self, init_text, nvpt): embedding_layer = self.wrapped.transformer.text_model.embeddings ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"] - embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) + embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0) return embedded -- cgit v1.2.3 From 79e39fae6110c20a3ee6255e2841c877f65e8cbd Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 01:45:28 +0300 Subject: CLIP hijack rework --- modules/sd_hijack_clip.py | 348 +++++++++++++++++++++++----------------------- 1 file changed, 171 insertions(+), 177 deletions(-) (limited to 'modules/sd_hijack_clip.py') diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index ca92b142..ac3020d7 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -1,12 +1,28 @@ import math +from collections import namedtuple import torch from modules import prompt_parser, devices from modules.shared import opts -def get_target_prompt_token_count(token_count): - return math.ceil(max(token_count, 1) / 75) * 75 + +class PromptChunk: + """ + This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt. + If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary. + Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token, + so just 75 tokens from prompt. 
+ """ + + def __init__(self): + self.tokens = [] + self.multipliers = [] + self.fixes = [] + + +PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding']) +"""This is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt chunk""" class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): @@ -14,17 +30,49 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): super().__init__() self.wrapped = wrapped self.hijack = hijack + self.chunk_length = 75 + + def empty_chunk(self): + """creates an empty PromptChunk and returns it""" + + chunk = PromptChunk() + chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1) + chunk.multipliers = [1.0] * (self.chunk_length + 2) + return chunk + + def get_target_prompt_token_count(self, token_count): + """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented""" + + return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length def tokenize(self, texts): + """Converts a batch of texts into a batch of token ids""" + raise NotImplementedError def encode_with_transformers(self, tokens): + """ + converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens; + All python lists with tokens are assumed to have same length, usually 77. + if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on + model - can be 768 and 1024 + """ + raise NotImplementedError def encode_embedding_init_text(self, init_text, nvpt): + """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through + transformers. nvpt is used as a maximum length in tokens. If text produces less teokens than nvpt, only this many is returned.""" + raise NotImplementedError - def tokenize_line(self, line, used_custom_terms, hijack_comments): + def tokenize_line(self, line): + """ + this transforms a single prompt into a list of PromptChunk objects - as many as needed to + represent the prompt. + Returns the list and the total number of tokens in the prompt. 
+ """ + if opts.enable_emphasis: parsed = prompt_parser.parse_prompt_attention(line) else: @@ -32,205 +80,152 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): tokenized = self.tokenize([text for text, _ in parsed]) - fixes = [] - remade_tokens = [] - multipliers = [] + chunks = [] + chunk = PromptChunk() + token_count = 0 last_comma = -1 - for tokens, (text, weight) in zip(tokenized, parsed): - i = 0 - while i < len(tokens): - token = tokens[i] + def next_chunk(): + """puts current chunk into the list of results and produces the next one - empty""" + nonlocal token_count + nonlocal last_comma + nonlocal chunk + + token_count += len(chunk.tokens) + to_add = self.chunk_length - len(chunk.tokens) + if to_add > 0: + chunk.tokens += [self.id_end] * to_add + chunk.multipliers += [1.0] * to_add - embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end] + chunk.multipliers = [1.0] + chunk.multipliers + [1.0] + + last_comma = -1 + chunks.append(chunk) + chunk = PromptChunk() + + for tokens, (text, weight) in zip(tokenized, parsed): + position = 0 + while position < len(tokens): + token = tokens[position] if token == self.comma_token: - last_comma = len(remade_tokens) - elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack: - last_comma += 1 - reloc_tokens = remade_tokens[last_comma:] - reloc_mults = multipliers[last_comma:] + last_comma = len(chunk.tokens) + + # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack + # is a setting that specifies that is there is a comma nearby, the text after comma should be moved out of this chunk and into the next. 
+ elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack: + break_location = last_comma + 1 + + reloc_tokens = chunk.tokens[break_location:] + reloc_mults = chunk.multipliers[break_location:] - remade_tokens = remade_tokens[:last_comma] - length = len(remade_tokens) + chunk.tokens = chunk.tokens[:break_location] + chunk.multipliers = chunk.multipliers[:break_location] - rem = int(math.ceil(length / 75)) * 75 - length - remade_tokens += [self.id_end] * rem + reloc_tokens - multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults + next_chunk() + chunk.tokens = reloc_tokens + chunk.multipliers = reloc_mults + if len(chunk.tokens) == self.chunk_length: + next_chunk() + + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position) if embedding is None: - remade_tokens.append(token) - multipliers.append(weight) - i += 1 - else: - emb_len = int(embedding.vec.shape[0]) - iteration = len(remade_tokens) // 75 - if (len(remade_tokens) + emb_len) // 75 != iteration: - rem = (75 * (iteration + 1) - len(remade_tokens)) - remade_tokens += [self.id_end] * rem - multipliers += [1.0] * rem - iteration += 1 - fixes.append((iteration, (len(remade_tokens) % 75, embedding))) - remade_tokens += [0] * emb_len - multipliers += [weight] * emb_len - used_custom_terms.append((embedding.name, embedding.checksum())) - i += embedding_length_in_tokens - - token_count = len(remade_tokens) - prompt_target_length = get_target_prompt_token_count(token_count) - tokens_to_add = prompt_target_length - len(remade_tokens) - - remade_tokens = remade_tokens + [self.id_end] * tokens_to_add - multipliers = multipliers + [1.0] * tokens_to_add - - return remade_tokens, fixes, multipliers, token_count - - def process_text(self, texts): - used_custom_terms = [] - remade_batch_tokens = [] - hijack_comments = [] - hijack_fixes = [] + chunk.tokens.append(token) + chunk.multipliers.append(weight) + position += 1 + continue + + emb_len = int(embedding.vec.shape[0]) + if len(chunk.tokens) + emb_len > self.chunk_length: + next_chunk() + + chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding)) + + chunk.tokens += [0] * emb_len + chunk.multipliers += [weight] * emb_len + position += embedding_length_in_tokens + + if len(chunk.tokens) > 0: + next_chunk() + + return chunks, token_count + + def process_texts(self, texts): + """ + Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum + length, in tokens, of all texts. 
+ """ + token_count = 0 cache = {} - batch_multipliers = [] + batch_chunks = [] for line in texts: if line in cache: - remade_tokens, fixes, multipliers = cache[line] + chunks = cache[line] else: - remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) + chunks, current_token_count = self.tokenize_line(line) token_count = max(current_token_count, token_count) - cache[line] = (remade_tokens, fixes, multipliers) + cache[line] = chunks - remade_batch_tokens.append(remade_tokens) - hijack_fixes.append(fixes) - batch_multipliers.append(multipliers) + batch_chunks.append(chunks) - return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count + return batch_chunks, token_count - def process_text_old(self, texts): - id_start = self.id_start - id_end = self.id_end - maxlen = self.wrapped.max_length # you get to stay at 77 - used_custom_terms = [] - remade_batch_tokens = [] - hijack_comments = [] - hijack_fixes = [] - token_count = 0 + def forward(self, texts): + """ + Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts. + Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will + be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024. + An example shape returned by this function can be: (2, 77, 768). + Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet + is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream" + """ - cache = {} - batch_tokens = self.tokenize(texts) - batch_multipliers = [] - for tokens in batch_tokens: - tuple_tokens = tuple(tokens) + if opts.use_old_emphasis_implementation: + import modules.sd_hijack_clip_old + return modules.sd_hijack_clip_old.forward_old(self, texts) - if tuple_tokens in cache: - remade_tokens, fixes, multipliers = cache[tuple_tokens] - else: - fixes = [] - remade_tokens = [] - multipliers = [] - mult = 1.0 - - i = 0 - while i < len(tokens): - token = tokens[i] - - embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) - - mult_change = self.token_mults.get(token) if opts.enable_emphasis else None - if mult_change is not None: - mult *= mult_change - i += 1 - elif embedding is None: - remade_tokens.append(token) - multipliers.append(mult) - i += 1 - else: - emb_len = int(embedding.vec.shape[0]) - fixes.append((len(remade_tokens), embedding)) - remade_tokens += [0] * emb_len - multipliers += [mult] * emb_len - used_custom_terms.append((embedding.name, embedding.checksum())) - i += embedding_length_in_tokens - - if len(remade_tokens) > maxlen - 2: - vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} - ovf = remade_tokens[maxlen - 2:] - overflowing_words = [vocab.get(int(x), "") for x in ovf] - overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") - - token_count = len(remade_tokens) - remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) - remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] - cache[tuple_tokens] = (remade_tokens, fixes, multipliers) - - multipliers = 
multipliers + [1.0] * (maxlen - 2 - len(multipliers)) - multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] - - remade_batch_tokens.append(remade_tokens) - hijack_fixes.append(fixes) - batch_multipliers.append(multipliers) - return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - - def forward(self, text): - use_old = opts.use_old_emphasis_implementation - if use_old: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) - else: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - - self.hijack.comments += hijack_comments - - if len(used_custom_terms) > 0: - self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) - - if use_old: - self.hijack.fixes = hijack_fixes - return self.process_tokens(remade_batch_tokens, batch_multipliers) - - z = None - i = 0 - while max(map(len, remade_batch_tokens)) != 0: - rem_tokens = [x[75:] for x in remade_batch_tokens] - rem_multipliers = [x[75:] for x in batch_multipliers] - - self.hijack.fixes = [] - for unfiltered in hijack_fixes: - fixes = [] - for fix in unfiltered: - if fix[0] == i: - fixes.append(fix[1]) - self.hijack.fixes.append(fixes) - - tokens = [] - multipliers = [] - for j in range(len(remade_batch_tokens)): - if len(remade_batch_tokens[j]) > 0: - tokens.append(remade_batch_tokens[j][:75]) - multipliers.append(batch_multipliers[j][:75]) - else: - tokens.append([self.id_end] * 75) - multipliers.append([1.0] * 75) - - z1 = self.process_tokens(tokens, multipliers) - z = z1 if z is None else torch.cat((z, z1), axis=-2) - - remade_batch_tokens = rem_tokens - batch_multipliers = rem_multipliers - i += 1 + batch_chunks, token_count = self.process_texts(texts) - return z + used_embeddings = {} + chunk_count = max([len(x) for x in batch_chunks]) - def process_tokens(self, remade_batch_tokens, batch_multipliers): - if not opts.use_old_emphasis_implementation: - remade_batch_tokens = [[self.id_start] + x[:75] + [self.id_end] for x in remade_batch_tokens] - batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] + zs = [] + for i in range(chunk_count): + batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks] + + tokens = [x.tokens for x in batch_chunk] + multipliers = [x.multipliers for x in batch_chunk] + self.hijack.fixes = [x.fixes for x in batch_chunk] + for fixes in self.hijack.fixes: + for position, embedding in fixes: + used_embeddings[embedding.name] = embedding + + z = self.process_tokens(tokens, multipliers) + zs.append(z) + + if len(used_embeddings) > 0: + embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()]) + self.hijack.comments.append(f"Used embeddings: {embeddings_list}") + + return torch.hstack(zs) + + def process_tokens(self, remade_batch_tokens, batch_multipliers): + """ + sends one single prompt chunk to be encoded by transformers neural network. + remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually + there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens. + Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier + corresponds to one token. 
+ """ tokens = torch.asarray(remade_batch_tokens).to(devices.device) + # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones. if self.id_end != self.id_pad: for batch_pos in range(len(remade_batch_tokens)): index = remade_batch_tokens[batch_pos].index(self.id_end) @@ -239,8 +234,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): z = self.encode_with_transformers(tokens) # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise - batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers] - batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(devices.device) + batch_multipliers = torch.asarray(batch_multipliers).to(devices.device) original_mean = z.mean() z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) new_mean = z.mean() -- cgit v1.2.3 From 08066676a47b560235d4c085dd3cfcb470b80997 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 07:22:07 +0300 Subject: make it not break on empty inputs; thank you tarded, we are --- modules/sd_hijack_clip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack_clip.py') diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index ac3020d7..16aef76a 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -147,7 +147,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): chunk.multipliers += [weight] * emb_len position += embedding_length_in_tokens - if len(chunk.tokens) > 0: + if len(chunk.tokens) > 0 or len(chunks) == 0: next_chunk() return chunks, token_count -- cgit v1.2.3 From 1740c33547b62f692834c95914a2b295d51684c7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 7 Jan 2023 07:48:44 +0300 Subject: more comments --- modules/sd_hijack_clip.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) (limited to 'modules/sd_hijack_clip.py') diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 16aef76a..5520c9b2 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -3,7 +3,7 @@ from collections import namedtuple import torch -from modules import prompt_parser, devices +from modules import prompt_parser, devices, sd_hijack from modules.shared import opts @@ -22,14 +22,24 @@ class PromptChunk: PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding']) -"""This is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt chunk""" +"""An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt +chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally +are applied by sd_hijack.EmbeddingsWithFixes's forward function.""" class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): + """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to + have unlimited prompt length and assign weights to tokens in prompt. 
+ """ + def __init__(self, wrapped, hijack): super().__init__() + self.wrapped = wrapped - self.hijack = hijack + """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation, + depending on model.""" + + self.hijack: sd_hijack.StableDiffusionModelHijack = hijack self.chunk_length = 75 def empty_chunk(self): @@ -55,7 +65,8 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens; All python lists with tokens are assumed to have same length, usually 77. if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on - model - can be 768 and 1024 + model - can be 768 and 1024. + Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None). """ raise NotImplementedError @@ -113,7 +124,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): last_comma = len(chunk.tokens) # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack - # is a setting that specifies that is there is a comma nearby, the text after comma should be moved out of this chunk and into the next. + # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next. elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack: break_location = last_comma + 1 -- cgit v1.2.3 From df3b31eb559ab9fabf7e513bdeddd5282c16f124 Mon Sep 17 00:00:00 2001 From: brkirch Date: Sat, 7 Jan 2023 07:04:59 -0500 Subject: In-place operations can break gradient calculation --- modules/sd_hijack_clip.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/sd_hijack_clip.py') diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 5520c9b2..852afc66 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -247,9 +247,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers = torch.asarray(batch_multipliers).to(devices.device) original_mean = z.mean() - z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) + z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) new_mean = z.mean() - z *= original_mean / new_mean + z = z * (original_mean / new_mean) return z -- cgit v1.2.3
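The patches above revolve around splitting a prompt of arbitrary length into fixed 75-token chunks, each wrapped with a start and an end token to form the 77 tokens CLIP expects; forward() then encodes each chunk separately and concatenates the results along the token axis. A minimal sketch of that chunking, assuming plain integer token ids and ignoring comma backtracking and textual inversion embeddings (BOS_ID, EOS_ID and CHUNK_LEN stand in for self.id_start, self.id_end and self.chunk_length):

BOS_ID, EOS_ID, CHUNK_LEN = 49406, 49407, 75   # CLIP tokenizer's bos/eos ids

def split_into_chunks(token_ids):
    """Split a flat list of token ids into 77-token chunks: up to 75 prompt
    tokens padded with EOS, then wrapped with one BOS and one EOS, mirroring
    next_chunk() in tokenize_line()."""
    chunks = []
    # max(..., 1) keeps one (empty) chunk for an empty prompt, matching the
    # "make it not break on empty inputs" fix
    for start in range(0, max(len(token_ids), 1), CHUNK_LEN):
        body = token_ids[start:start + CHUNK_LEN]
        body = body + [EOS_ID] * (CHUNK_LEN - len(body))   # pad to 75 tokens
        chunks.append([BOS_ID] + body + [EOS_ID])          # 77 tokens total
    return chunks

# a 100-token prompt becomes two 77-token chunks
chunks = split_into_chunks(list(range(100)))
assert len(chunks) == 2 and all(len(c) == 77 for c in chunks)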
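A PromptChunkFix only records an offset and an embedding; the vectors themselves are spliced into the token embeddings later by sd_hijack.EmbeddingsWithFixes, which is not part of this file. A simplified sketch of what applying one fix amounts to, assuming token_embeddings has shape (77, C) and vec has shape (n, C); the +1 accounts for the start token that next_chunk() prepends after the offset is recorded:

import torch

def apply_fix(token_embeddings, offset, vec):
    """Overwrite the n placeholder slots reserved for a textual inversion
    embedding with its actual vectors."""
    emb_len = vec.shape[0]
    out = token_embeddings.clone()
    out[offset + 1:offset + 1 + emb_len] = vec.to(out.dtype)
    return out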
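process_tokens() weights each token's hidden state by its multiplier and then rescales the whole tensor so its mean matches that of the unweighted output; the last patch switches to out-of-place operations so gradient calculation (for example during textual inversion training) is not broken. A hedged sketch of that step, assuming z has shape (B, 77, C) and multipliers is a list of B lists of 77 floats:

import torch

def apply_multipliers(z, multipliers):
    m = torch.asarray(multipliers, dtype=z.dtype, device=z.device)   # (B, 77)
    original_mean = z.mean()
    z = z * m.unsqueeze(-1).expand(z.shape)   # scale every channel of each token
    return z * (original_mean / z.mean())     # restore the original mean

z = torch.randn(1, 77, 768)
weights = [[1.0] * 10 + [1.3] * 5 + [1.0] * 62]   # emphasise tokens 10..14
z_weighted = apply_multipliers(z, weights)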
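encode_with_transformers() in FrozenCLIPEmbedderWithCustomWords implements "CLIP skip": when CLIP_stop_at_last_layers is greater than 1, the hidden states of an earlier transformer layer are taken and passed through the final layer norm instead of using last_hidden_state. A sketch against the Hugging Face CLIP text model (the model name is illustrative; the webui reuses the checkpoint's already-loaded FrozenCLIPEmbedder rather than loading its own):

import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

def encode(texts, clip_skip=1):
    ids = tokenizer(texts, padding="max_length", max_length=77, truncation=True, return_tensors="pt")["input_ids"]
    outputs = model(input_ids=ids, output_hidden_states=clip_skip > 1)
    if clip_skip > 1:
        z = model.text_model.final_layer_norm(outputs.hidden_states[-clip_skip])   # earlier layer
    else:
        z = outputs.last_hidden_state
    return z   # (B, 77, 768)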