aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_hijack_clip.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_hijack_clip.py')
-rw-r--r--modules/sd_hijack_clip.py308
1 files changed, 308 insertions, 0 deletions
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
new file mode 100644
index 00000000..852afc66
--- /dev/null
+++ b/modules/sd_hijack_clip.py
@@ -0,0 +1,308 @@
+import math
+from collections import namedtuple
+
+import torch
+
+from modules import prompt_parser, devices, sd_hijack
+from modules.shared import opts
+
+
+class PromptChunk:
+ """
+ This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt.
+ If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary.
+ Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token,
+ so just 75 tokens from prompt.
+ """
+
+ def __init__(self):
+ self.tokens = []
+ self.multipliers = []
+ self.fixes = []
+
+
+PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding'])
+"""An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt
+chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally
+are applied by sd_hijack.EmbeddingsWithFixes's forward function."""
+
+
+class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
+ """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to
+ have unlimited prompt length and assign weights to tokens in prompt.
+ """
+
+ def __init__(self, wrapped, hijack):
+ super().__init__()
+
+ self.wrapped = wrapped
+ """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation,
+ depending on model."""
+
+ self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
+ self.chunk_length = 75
+
+ def empty_chunk(self):
+ """creates an empty PromptChunk and returns it"""
+
+ chunk = PromptChunk()
+ chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1)
+ chunk.multipliers = [1.0] * (self.chunk_length + 2)
+ return chunk
+
+ def get_target_prompt_token_count(self, token_count):
+ """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented"""
+
+ return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length
+
+ def tokenize(self, texts):
+ """Converts a batch of texts into a batch of token ids"""
+
+ raise NotImplementedError
+
+ def encode_with_transformers(self, tokens):
+ """
+ converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens;
+ All python lists with tokens are assumed to have same length, usually 77.
+ if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on
+ model - can be 768 and 1024.
+ Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None).
+ """
+
+ raise NotImplementedError
+
+ def encode_embedding_init_text(self, init_text, nvpt):
+ """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through
+ transformers. nvpt is used as a maximum length in tokens. If text produces less teokens than nvpt, only this many is returned."""
+
+ raise NotImplementedError
+
+ def tokenize_line(self, line):
+ """
+ this transforms a single prompt into a list of PromptChunk objects - as many as needed to
+ represent the prompt.
+ Returns the list and the total number of tokens in the prompt.
+ """
+
+ if opts.enable_emphasis:
+ parsed = prompt_parser.parse_prompt_attention(line)
+ else:
+ parsed = [[line, 1.0]]
+
+ tokenized = self.tokenize([text for text, _ in parsed])
+
+ chunks = []
+ chunk = PromptChunk()
+ token_count = 0
+ last_comma = -1
+
+ def next_chunk():
+ """puts current chunk into the list of results and produces the next one - empty"""
+ nonlocal token_count
+ nonlocal last_comma
+ nonlocal chunk
+
+ token_count += len(chunk.tokens)
+ to_add = self.chunk_length - len(chunk.tokens)
+ if to_add > 0:
+ chunk.tokens += [self.id_end] * to_add
+ chunk.multipliers += [1.0] * to_add
+
+ chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end]
+ chunk.multipliers = [1.0] + chunk.multipliers + [1.0]
+
+ last_comma = -1
+ chunks.append(chunk)
+ chunk = PromptChunk()
+
+ for tokens, (text, weight) in zip(tokenized, parsed):
+ position = 0
+ while position < len(tokens):
+ token = tokens[position]
+
+ if token == self.comma_token:
+ last_comma = len(chunk.tokens)
+
+ # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack
+ # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next.
+ elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack:
+ break_location = last_comma + 1
+
+ reloc_tokens = chunk.tokens[break_location:]
+ reloc_mults = chunk.multipliers[break_location:]
+
+ chunk.tokens = chunk.tokens[:break_location]
+ chunk.multipliers = chunk.multipliers[:break_location]
+
+ next_chunk()
+ chunk.tokens = reloc_tokens
+ chunk.multipliers = reloc_mults
+
+ if len(chunk.tokens) == self.chunk_length:
+ next_chunk()
+
+ embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position)
+ if embedding is None:
+ chunk.tokens.append(token)
+ chunk.multipliers.append(weight)
+ position += 1
+ continue
+
+ emb_len = int(embedding.vec.shape[0])
+ if len(chunk.tokens) + emb_len > self.chunk_length:
+ next_chunk()
+
+ chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding))
+
+ chunk.tokens += [0] * emb_len
+ chunk.multipliers += [weight] * emb_len
+ position += embedding_length_in_tokens
+
+ if len(chunk.tokens) > 0 or len(chunks) == 0:
+ next_chunk()
+
+ return chunks, token_count
+
+ def process_texts(self, texts):
+ """
+ Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum
+ length, in tokens, of all texts.
+ """
+
+ token_count = 0
+
+ cache = {}
+ batch_chunks = []
+ for line in texts:
+ if line in cache:
+ chunks = cache[line]
+ else:
+ chunks, current_token_count = self.tokenize_line(line)
+ token_count = max(current_token_count, token_count)
+
+ cache[line] = chunks
+
+ batch_chunks.append(chunks)
+
+ return batch_chunks, token_count
+
+ def forward(self, texts):
+ """
+ Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
+ Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
+ be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024.
+ An example shape returned by this function can be: (2, 77, 768).
+ Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet
+ is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
+ """
+
+ if opts.use_old_emphasis_implementation:
+ import modules.sd_hijack_clip_old
+ return modules.sd_hijack_clip_old.forward_old(self, texts)
+
+ batch_chunks, token_count = self.process_texts(texts)
+
+ used_embeddings = {}
+ chunk_count = max([len(x) for x in batch_chunks])
+
+ zs = []
+ for i in range(chunk_count):
+ batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks]
+
+ tokens = [x.tokens for x in batch_chunk]
+ multipliers = [x.multipliers for x in batch_chunk]
+ self.hijack.fixes = [x.fixes for x in batch_chunk]
+
+ for fixes in self.hijack.fixes:
+ for position, embedding in fixes:
+ used_embeddings[embedding.name] = embedding
+
+ z = self.process_tokens(tokens, multipliers)
+ zs.append(z)
+
+ if len(used_embeddings) > 0:
+ embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()])
+ self.hijack.comments.append(f"Used embeddings: {embeddings_list}")
+
+ return torch.hstack(zs)
+
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
+ """
+ sends one single prompt chunk to be encoded by transformers neural network.
+ remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually
+ there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens.
+ Multipliers are used to give more or less weight to the outputs of transformers network. Each multiplier
+ corresponds to one token.
+ """
+ tokens = torch.asarray(remade_batch_tokens).to(devices.device)
+
+ # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones.
+ if self.id_end != self.id_pad:
+ for batch_pos in range(len(remade_batch_tokens)):
+ index = remade_batch_tokens[batch_pos].index(self.id_end)
+ tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad
+
+ z = self.encode_with_transformers(tokens)
+
+ # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
+ batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
+ original_mean = z.mean()
+ z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+ new_mean = z.mean()
+ z = z * (original_mean / new_mean)
+
+ return z
+
+
+class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
+ def __init__(self, wrapped, hijack):
+ super().__init__(wrapped, hijack)
+ self.tokenizer = wrapped.tokenizer
+
+ vocab = self.tokenizer.get_vocab()
+
+ self.comma_token = vocab.get(',</w>', None)
+
+ self.token_mults = {}
+ tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
+ for text, ident in tokens_with_parens:
+ mult = 1.0
+ for c in text:
+ if c == '[':
+ mult /= 1.1
+ if c == ']':
+ mult *= 1.1
+ if c == '(':
+ mult *= 1.1
+ if c == ')':
+ mult /= 1.1
+
+ if mult != 1.0:
+ self.token_mults[ident] = mult
+
+ self.id_start = self.wrapped.tokenizer.bos_token_id
+ self.id_end = self.wrapped.tokenizer.eos_token_id
+ self.id_pad = self.id_end
+
+ def tokenize(self, texts):
+ tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
+
+ return tokenized
+
+ def encode_with_transformers(self, tokens):
+ outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
+
+ if opts.CLIP_stop_at_last_layers > 1:
+ z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
+ z = self.wrapped.transformer.text_model.final_layer_norm(z)
+ else:
+ z = outputs.last_hidden_state
+
+ return z
+
+ def encode_embedding_init_text(self, init_text, nvpt):
+ embedding_layer = self.wrapped.transformer.text_model.embeddings
+ ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
+ embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
+
+ return embedded