From 345028099d893f8a66726cfd13627d8cc1bcc724 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 3 Sep 2022 12:08:45 +0300
Subject: split codebase into multiple files; to anyone this affects negatively: sorry

---
 modules/sd_hijack.py | 208 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 208 insertions(+)
 create mode 100644 modules/sd_hijack.py
(limited to 'modules/sd_hijack.py')

diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
new file mode 100644
index 00000000..6ee92e77
--- /dev/null
+++ b/modules/sd_hijack.py
@@ -0,0 +1,208 @@
+import os
+import sys
+import traceback
+import torch
+import numpy as np
+
+from modules.shared import opts, device
+
+
+class StableDiffusionModelHijack:
+    ids_lookup = {}
+    word_embeddings = {}
+    word_embeddings_checksums = {}
+    fixes = None
+    comments = []
+    dir_mtime = None
+
+    def load_textual_inversion_embeddings(self, dirname, model):
+        mt = os.path.getmtime(dirname)
+        if self.dir_mtime is not None and mt <= self.dir_mtime:
+            return
+
+        self.dir_mtime = mt
+        self.ids_lookup.clear()
+        self.word_embeddings.clear()
+
+        tokenizer = model.cond_stage_model.tokenizer
+
+        def const_hash(a):
+            r = 0
+            for v in a:
+                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
+            return r
+
+        def process_file(path, filename):
+            name = os.path.splitext(filename)[0]
+
+            data = torch.load(path)
+            param_dict = data['string_to_param']
+            if hasattr(param_dict, '_parameters'):
+                param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
+            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+            emb = next(iter(param_dict.items()))[1]
+            self.word_embeddings[name] = emb.detach()
+            self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1))&0xffff:04x}'
+
+            ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
+
+            first_id = ids[0]
+            if first_id not in self.ids_lookup:
+                self.ids_lookup[first_id] = []
+            self.ids_lookup[first_id].append((ids, name))
+
+        for fn in os.listdir(dirname):
+            try:
+                process_file(os.path.join(dirname, fn), fn)
+            except Exception:
+                print(f"Error loading embedding {fn}:", file=sys.stderr)
+                print(traceback.format_exc(), file=sys.stderr)
+                continue
+
+        print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
+
+    def hijack(self, m):
+        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+
+        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
+        m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+
+
+class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
+    def __init__(self, wrapped, hijack):
+        super().__init__()
+        self.wrapped = wrapped
+        self.hijack = hijack
+        self.tokenizer = wrapped.tokenizer
+        self.max_length = wrapped.max_length
+        self.token_mults = {}
+
+        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
+        for text, ident in tokens_with_parens:
+            mult = 1.0
+            for c in text:
+                if c == '[':
+                    mult /= 1.1
+                if c == ']':
+                    mult *= 1.1
+                if c == '(':
+                    mult *= 1.1
+                if c == ')':
+                    mult /= 1.1
+
+            if mult != 1.0:
+                self.token_mults[ident] = mult
+
+    def forward(self, text):
+        self.hijack.fixes = []
+        self.hijack.comments = []
+        remade_batch_tokens = []
+        id_start = self.wrapped.tokenizer.bos_token_id
+        id_end = self.wrapped.tokenizer.eos_token_id
+        maxlen = self.wrapped.max_length - 2
+        used_custom_terms = []
+
+        cache = {}
+        batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
+        batch_multipliers = []
+        for tokens in batch_tokens:
+            tuple_tokens = tuple(tokens)
+
+            if tuple_tokens in cache:
+                remade_tokens, fixes, multipliers = cache[tuple_tokens]
+            else:
+                fixes = []
+                remade_tokens = []
+                multipliers = []
+                mult = 1.0
+
+                i = 0
+                while i < len(tokens):
+                    token = tokens[i]
+
+                    possible_matches = self.hijack.ids_lookup.get(token, None)
+
+                    mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
+                    if mult_change is not None:
+                        mult *= mult_change
+                    elif possible_matches is None:
+                        remade_tokens.append(token)
+                        multipliers.append(mult)
+                    else:
+                        found = False
+                        for ids, word in possible_matches:
+                            if tokens[i:i+len(ids)] == ids:
+                                emb_len = int(self.hijack.word_embeddings[word].shape[0])
+                                fixes.append((len(remade_tokens), word))
+                                remade_tokens += [0] * emb_len
+                                multipliers += [mult] * emb_len
+                                i += len(ids) - 1
+                                found = True
+                                used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
+                                break
+
+                        if not found:
+                            remade_tokens.append(token)
+                            multipliers.append(mult)
+
+                    i += 1
+
+                if len(remade_tokens) > maxlen - 2:
+                    vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+                    ovf = remade_tokens[maxlen - 2:]
+                    overflowing_words = [vocab.get(int(x), "") for x in ovf]
+                    overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+
+                    self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+                remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
+                remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
+                cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
+
+            multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+            multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+
+            remade_batch_tokens.append(remade_tokens)
+            self.hijack.fixes.append(fixes)
+            batch_multipliers.append(multipliers)
+
+        if len(used_custom_terms) > 0:
+            self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+
+        tokens = torch.asarray(remade_batch_tokens).to(device)
+        outputs = self.wrapped.transformer(input_ids=tokens)
+        z = outputs.last_hidden_state
+
+        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
+        batch_multipliers = torch.asarray(np.array(batch_multipliers)).to(device)
+        original_mean = z.mean()
+        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+        new_mean = z.mean()
+        z *= original_mean / new_mean
+
+        return z
+
+
+class EmbeddingsWithFixes(torch.nn.Module):
+    def __init__(self, wrapped, embeddings):
+        super().__init__()
+        self.wrapped = wrapped
+        self.embeddings = embeddings
+
+    def forward(self, input_ids):
+        batch_fixes = self.embeddings.fixes
+        self.embeddings.fixes = None
+
+        inputs_embeds = self.wrapped(input_ids)
+
+        if batch_fixes is not None:
+            for fixes, tensor in zip(batch_fixes, inputs_embeds):
+                for offset, word in fixes:
+                    emb = self.embeddings.word_embeddings[word]
+                    emb_len = min(tensor.shape[0]-offset, emb.shape[0])
+                    tensor[offset:offset+emb_len] = self.embeddings.word_embeddings[word][0:emb_len]
+
+        return inputs_embeds
+
+
+model_hijack = StableDiffusionModelHijack()
-- 
cgit v1.2.3
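
For readers wiring this new module into the rest of the webui, a minimal usage sketch follows. It is not part of the patch above: the wrapper function, the "embeddings" directory name, and the point at which the hijack is applied are assumptions made for illustration; only the `model_hijack` singleton, `hijack()`, and `load_textual_inversion_embeddings()` come from the file added by this commit.

# Illustrative sketch only -- not part of this commit. Assumes `sd_model` is an
# already-loaded LatentDiffusion model whose cond_stage_model is the upstream
# FrozenCLIPEmbedder; the "embeddings" directory name is also an assumption.
import os

from modules.sd_hijack import model_hijack


def prepare_model(sd_model, embeddings_dir="embeddings"):
    # Wrap the CLIP text encoder and its token embedding layer so prompts gain
    # (emphasis)/[de-emphasis] weighting and can reference custom embeddings.
    model_hijack.hijack(sd_model)

    # Scan the directory for textual inversion .pt files and register them so
    # their names resolve to learned embeddings during tokenization.
    if os.path.isdir(embeddings_dir):
        model_hijack.load_textual_inversion_embeddings(embeddings_dir, sd_model)

    return sd_model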