diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-08-27 08:17:55 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-08-27 08:17:55 +0000 |
commit | 9597b265ec07e8ec6dab7487152459046585c1f9 (patch) | |
tree | 5eafd5fc597851b68b0beeeeba800916f7f91858 /webui.py | |
parent | a51bedfb5ae2e1adeb7406b183305b2ea7530eac (diff) | |
download | stable-diffusion-webui-gfx803-9597b265ec07e8ec6dab7487152459046585c1f9.tar.gz stable-diffusion-webui-gfx803-9597b265ec07e8ec6dab7487152459046585c1f9.tar.bz2 stable-diffusion-webui-gfx803-9597b265ec07e8ec6dab7487152459046585c1f9.zip |
implementation for attention using [] and ()
Diffstat (limited to 'webui.py')
-rw-r--r-- | webui.py | 79 |
1 files changed, 56 insertions, 23 deletions
@@ -433,15 +433,15 @@ if os.path.exists(cmd_opts.gfpgan_dir): print(traceback.format_exc(), file=sys.stderr)
-class TextInversionEmbeddings:
+class StableDiffuionModelHijack:
ids_lookup = {}
word_embeddings = {}
word_embeddings_checksums = {}
- fixes = []
+ fixes = None
used_custom_terms = []
dir_mtime = None
- def load(self, dir, model):
+ def load_textual_inversion_embeddings(self, dir, model):
mt = os.path.getmtime(dir)
if self.dir_mtime is not None and mt <= self.dir_mtime:
return
@@ -469,6 +469,7 @@ class TextInversionEmbeddings: self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
+
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
@@ -497,6 +498,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.embeddings = embeddings
self.tokenizer = wrapped.tokenizer
self.max_length = wrapped.max_length
+ self.token_mults = {}
+
+ tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
+ for text, ident in tokens_with_parens:
+ mult = 1.0
+ for c in text:
+ if c == '[':
+ mult /= 1.1
+ if c == ']':
+ mult *= 1.1
+ if c == '(':
+ mult *= 1.1
+ if c == ')':
+ mult /= 1.1
+
+ if mult != 1.0:
+ self.token_mults[ident] = mult
def forward(self, text):
self.embeddings.fixes = []
@@ -508,14 +526,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): cache = {}
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
+ batch_multipliers = []
for tokens in batch_tokens:
tuple_tokens = tuple(tokens)
if tuple_tokens in cache:
- remade_tokens, fixes = cache[tuple_tokens]
+ remade_tokens, fixes, multipliers = cache[tuple_tokens]
else:
fixes = []
remade_tokens = []
+ multipliers = []
+ mult = 1.0
i = 0
while i < len(tokens):
@@ -523,14 +544,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): possible_matches = self.embeddings.ids_lookup.get(token, None)
- if possible_matches is None:
+ mult_change = self.token_mults.get(token)
+ if mult_change is not None:
+ mult *= mult_change
+ elif possible_matches is None:
remade_tokens.append(token)
+ multipliers.append(mult)
else:
found = False
for ids, word in possible_matches:
if tokens[i:i+len(ids)] == ids:
fixes.append((len(remade_tokens), word))
remade_tokens.append(777)
+ multipliers.append(mult)
i += len(ids) - 1
found = True
self.embeddings.used_custom_terms.append((word, self.embeddings.word_embeddings_checksums[word]))
@@ -538,19 +564,32 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if not found:
remade_tokens.append(token)
+ multipliers.append(mult)
i += 1
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
- cache[tuple_tokens] = (remade_tokens, fixes)
+ cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
+
+ multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+ multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
remade_batch_tokens.append(remade_tokens)
self.embeddings.fixes.append(fixes)
+ batch_multipliers.append(multipliers)
tokens = torch.asarray(remade_batch_tokens).to(self.wrapped.device)
outputs = self.wrapped.transformer(input_ids=tokens)
z = outputs.last_hidden_state
+
+ # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
+ batch_multipliers = torch.asarray(np.array(batch_multipliers)).to(device)
+ original_mean = z.mean()
+ z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+ new_mean = z.mean()
+ z *= original_mean / new_mean
+
return z
@@ -562,22 +601,17 @@ class EmbeddingsWithFixes(nn.Module): def forward(self, input_ids):
batch_fixes = self.embeddings.fixes
- self.embeddings.fixes = []
+ self.embeddings.fixes = None
inputs_embeds = self.wrapped(input_ids)
- for fixes, tensor in zip(batch_fixes, inputs_embeds):
- for offset, word in fixes:
- tensor[offset] = self.embeddings.word_embeddings[word]
-
- return inputs_embeds
+ if batch_fixes is not None:
+ for fixes, tensor in zip(batch_fixes, inputs_embeds):
+ for offset, word in fixes:
+ tensor[offset] = self.embeddings.word_embeddings[word]
-def get_learned_conditioning_with_embeddings(model, prompts):
- if os.path.exists(cmd_opts.embeddings_dir):
- text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
-
- return model.get_learned_conditioning(prompts)
+ return inputs_embeds
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN, do_not_save_grid=False, extra_generation_params=None):
@@ -648,7 +682,7 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
if os.path.exists(cmd_opts.embeddings_dir):
- text_inversion_embeddings.load(cmd_opts.embeddings_dir, model)
+ model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)
output_images = []
with torch.no_grad(), autocast("cuda"), model.ema_scope():
@@ -661,8 +695,8 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_index, uc = model.get_learned_conditioning(len(prompts) * [""])
c = model.get_learned_conditioning(prompts)
- if len(text_inversion_embeddings.used_custom_terms) > 0:
- comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in text_inversion_embeddings.used_custom_terms]))
+ if len(model_hijack.used_custom_terms) > 0:
+ comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in model_hijack.used_custom_terms]))
# we manually generate all input noises because each one should have a specific seed
x = create_random_tensors([opt_C, height // opt_f, width // opt_f], seeds=seeds)
@@ -1060,10 +1094,9 @@ model = load_model_from_config(config, cmd_opts.ckpt) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = (model if cmd_opts.no_half else model.half()).to(device)
-text_inversion_embeddings = TextInversionEmbeddings()
-if os.path.exists(cmd_opts.embeddings_dir):
- text_inversion_embeddings.hijack(model)
+model_hijack = StableDiffuionModelHijack()
+model_hijack.hijack(model)
demo = gr.TabbedInterface(
interface_list=[x[0] for x in interfaces],
|