From 9597b265ec07e8ec6dab7487152459046585c1f9 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sat, 27 Aug 2022 11:17:55 +0300
Subject: implementation for attention using [] and ()

---
 webui.py | 79 +++++++++++++++++++++++++++++++++++++++++++++-------------------
 1 file changed, 56 insertions(+), 23 deletions(-)

(limited to 'webui.py')

diff --git a/webui.py b/webui.py
index b3375e98..a0fa23c4 100644
--- a/webui.py
+++ b/webui.py
@@ -433,15 +433,15 @@ if os.path.exists(cmd_opts.gfpgan_dir):
         print(traceback.format_exc(), file=sys.stderr)
 
 
-class TextInversionEmbeddings:
+class StableDiffuionModelHijack:
     ids_lookup = {}
     word_embeddings = {}
     word_embeddings_checksums = {}
-    fixes = []
+    fixes = None
     used_custom_terms = []
     dir_mtime = None
 
-    def load(self, dir, model):
+    def load_textual_inversion_embeddings(self, dir, model):
         mt = os.path.getmtime(dir)
         if self.dir_mtime is not None and mt <= self.dir_mtime:
             return
@@ -469,6 +469,7 @@ class TextInversionEmbeddings:
                 self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'
 
                 ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
+
                 first_id = ids[0]
                 if first_id not in self.ids_lookup:
                     self.ids_lookup[first_id] = []
@@ -497,6 +498,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         self.embeddings = embeddings
         self.tokenizer = wrapped.tokenizer
         self.max_length = wrapped.max_length
+        self.token_mults = {}
+
+        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
+        for text, ident in tokens_with_parens:
+            mult = 1.0
+            for c in text:
+                if c == '[':
+                    mult /= 1.1
+                if c == ']':
+                    mult *= 1.1
+                if c == '(':
+                    mult *= 1.1
+                if c == ')':
+                    mult /= 1.1
+
+            if mult != 1.0:
+                self.token_mults[ident] = mult
 
     def forward(self, text):
         self.embeddings.fixes = []
@@ -508,14 +526,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
 
         cache = {}
         batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
+        batch_multipliers = []
         for tokens in batch_tokens:
            tuple_tokens = tuple(tokens)
 
             if tuple_tokens in cache:
-                remade_tokens, fixes = cache[tuple_tokens]
+                remade_tokens, fixes, multipliers = cache[tuple_tokens]
             else:
                 fixes = []
                 remade_tokens = []
+                multipliers = []
+                mult = 1.0
 
                 i = 0
                 while i < len(tokens):
@@ -523,14 +544,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
 
                     possible_matches = self.embeddings.ids_lookup.get(token, None)
 
-                    if possible_matches is None:
+                    mult_change = self.token_mults.get(token)
+                    if mult_change is not None:
+                        mult *= mult_change
+                    elif possible_matches is None:
                         remade_tokens.append(token)
+                        multipliers.append(mult)
                     else:
                         found = False
                         for ids, word in possible_matches:
                             if tokens[i:i+len(ids)] == ids:
                                 fixes.append((len(remade_tokens), word))
                                 remade_tokens.append(777)
+                                multipliers.append(mult)
                                 i += len(ids) - 1
                                 found = True
                                 self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word]))
@@ -538,19 +564,32 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
 
                         if not found:
                             remade_tokens.append(token)
+                            multipliers.append(mult)
 
                     i += 1
 
                 remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
                 remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
-                cache[tuple_tokens] = (remade_tokens, fixes)
+                cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
+
+            multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+            multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
 
             remade_batch_tokens.append(remade_tokens)
             self.embeddings.fixes.append(fixes)
+            batch_multipliers.append(multipliers)
 
         tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
         outputs = self.wrapped.transformer(input_ids=tokens)
         z = outputs.last_hidden_state
+
+        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
+        batch_multipliers = torch.asarray(np.array(batch_multipliers)).to(device)
+        original_mean = z.mean()
+        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+        new_mean = z.mean()
+        z *= original_mean / new_mean
+
         return z
 
 
@@ -562,22 +601,17 @@ class EmbeddingsWithFixes(nn.Module):
 
     def forward(self, input_ids):
         batch_fixes = self.embeddings.fixes
-        self.embeddings.fixes = []
+        self.embeddings.fixes = None
 
         inputs_embeds = self.wrapped(input_ids)
 
-        for fixes, tensor in zip(batch_fixes, inputs_embeds):
-            for offset, word in fixes:
-                tensor[offset] = self.embeddings.word_embeddings[word]
-
-        return inputs_embeds
+        if batch_fixes is not None:
+            for fixes, tensor in zip(batch_fixes, inputs_embeds):
+                for offset, word in fixes:
+                    tensor[offset] = self.embeddings.word_embeddings[word]
 
 
-def get_learned_conditioning_with_embeddings(model, prompts):
-    if os.path.exists(cmd_opts.embeddings_dir):
-        text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
-
-    return model.get_learned_conditioning(prompts)
+        return inputs_embeds
 
 
 def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, extra_generation_params=None):
@@ -648,7 +682,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
         return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
 
     if os.path.exists(cmd_opts.embeddings_dir):
-        text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
+        model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)
 
     output_images = []
     with torch.no_grad(), autocast("cuda"), model.ema_scope():
@@ -661,8 +695,8 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index,
             uc = model.get_learned_conditioning(len(prompts) * [""])
             c = model.get_learned_conditioning(prompts)
 
-            if len(text_inversion_embeddings.used_custom_terms) > 0:
-                comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in text_inversion_embeddings.used_custom_terms]))
+            if len(model_hijack.used_custom_terms) > 0:
+                comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in model_hijack.used_custom_terms]))
 
             # we manually generate all input noises because each one should have a specific seed
             x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
@@ -1060,10 +1094,9 @@ model = load_model_from_config(config, cmd_opts.ckpt)
 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 model = (model if cmd_opts.no_half else model.half()).to(device)
 
-text_inversion_embeddings = TextInversionEmbeddings()
-if os.path.exists(cmd_opts.embeddings_dir):
-    text_inversion_embeddings.hijack(model)
+model_hijack = StableDiffuionModelHijack()
+model_hijack.hijack(model)
 
 demo = gr.TabbedInterface(
     interface_list=[x[0] for x in interfaces],
--
cgit v1.2.3
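
For readers skimming the patch, here is a minimal standalone sketch of the two mechanisms it combines. This is illustrative code only, not code from webui.py; the names prompt_weights and apply_multipliers are hypothetical. The first function mirrors the token_mults idea, where each '(' raises an emphasis multiplier by 1.1x for everything up to the matching ')' and each '[' lowers it by the same factor; the second mirrors the patched forward(), which scales the CLIP hidden states by per-token multipliers and then rescales them so the overall mean matches the unweighted output.

import torch


def prompt_weights(prompt):
    # Walk the prompt character by character and track the current emphasis
    # multiplier, analogous to how the patch assigns a multiplier to every
    # vocab token containing parens/brackets and applies it to later tokens.
    weights = []
    mult = 1.0
    for c in prompt:
        if c == '(':
            mult *= 1.1      # '(' strengthens everything until the matching ')'
        elif c == ')':
            mult /= 1.1
        elif c == '[':
            mult /= 1.1      # '[' weakens everything until the matching ']'
        elif c == ']':
            mult *= 1.1
        else:
            weights.append((c, round(mult, 3)))
    return weights


def apply_multipliers(z, multipliers):
    # z: (batch, tokens, channels) hidden states from the text encoder.
    # multipliers: (batch, tokens) per-token emphasis, 1.0 where unweighted.
    # Scale each token's embedding, then rescale so the overall mean matches
    # the unscaled output; the commit's own comment calls this a heuristic
    # that avoids artifacts rather than a principled correction.
    original_mean = z.mean()
    z = z * multipliers.unsqueeze(-1).expand(z.shape)
    new_mean = z.mean()
    return z * (original_mean / new_mean)


if __name__ == "__main__":
    # 'cat' ends up near 1.21x, 'dark' near 0.91x, everything else at 1.0x
    print(prompt_weights("a ((cat)) in a [dark] alley"))

    z = torch.randn(1, 5, 768)
    mults = torch.tensor([[1.0, 1.21, 1.21, 0.909, 1.0]])
    print(apply_multipliers(z, mults).shape)  # torch.Size([1, 5, 768])

The mean-restoration step is the patch's own heuristic; as its inline comment notes, it is likely not strictly correct but prevents the artifacts that otherwise appear when the embeddings are simply scaled.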