From c2d5b29040132c171bc4d77f1f63da972306f22c Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Thu, 29 Sep 2022 01:14:54 -0300 Subject: Move silu to sd_hijack --- modules/sd_hijack.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index bfbd07f9..4bc58fa2 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -12,6 +12,7 @@ from ldm.util import default from einops import rearrange import ldm.modules.attention import ldm.modules.diffusionmodules.model +from torch.nn.functional import silu # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion @@ -100,14 +101,6 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) -def nonlinearity_hijack(x): - # swish - t = torch.sigmoid(x) - x *= t - del t - - return x - def cross_attention_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) @@ -245,11 +238,12 @@ class StableDiffusionModelHijack: m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) self.clip = m.cond_stage_model + ldm.modules.diffusionmodules.model.nonlinearity = silu + if cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward - ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward def flatten(el): -- cgit v1.2.3 From 820f1dc96b1979d7e92170c161db281ee8bd988b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 15:03:39 +0300 Subject: initial support for training textual inversion --- modules/sd_hijack.py | 324 ++++++++------------------------------------------- 1 file changed, 51 insertions(+), 273 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index fa7eaeb8..fd57e5c5 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -6,244 +6,41 @@ import torch import numpy as np from torch import einsum -from modules import prompt_parser +import modules.textual_inversion.textual_inversion +from modules import prompt_parser, devices, sd_hijack_optimizations, shared from modules.shared import opts, device, cmd_opts -from ldm.util import default -from einops import rearrange import ldm.modules.attention import ldm.modules.diffusionmodules.model +attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward +diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity +diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward -# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion -def split_cross_attention_forward_v1(self, x, context=None, mask=None): - h = self.heads - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - del context, x +def apply_optimizations(): + if cmd_opts.opt_split_attention_v1: + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 + elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward + ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - for i in range(0, q.shape[0], 2): - end = i + 2 - s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) - s1 *= self.scale +def undo_optimizations(): + ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward + ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity + ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward - s2 = s1.softmax(dim=-1) - del s1 - - r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) - del s2 - - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - - return self.to_out(r2) - - -# taken from https://github.com/Doggettx/stable-diffusion -def split_cross_attention_forward(self, x, context=None, mask=None): - h = self.heads - - q_in = self.to_q(x) - context = default(context, x) - k_in = self.to_k(context) * self.scale - v_in = self.to_v(context) - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - if mem_required > mem_free_total: - steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) - # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") - - if steps > 64: - max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 - raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' - f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) - - s2 = s1.softmax(dim=-1, dtype=q.dtype) - del s1 - - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 - - del q, k, v - - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - - return self.to_out(r2) - -def nonlinearity_hijack(x): - # swish - t = torch.sigmoid(x) - x *= t - del t - - return x - -def cross_attention_attnblock_forward(self, x): - h_ = x - h_ = self.norm(h_) - q1 = self.q(h_) - k1 = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q1.shape - - q2 = q1.reshape(b, c, h*w) - del q1 - - q = q2.permute(0, 2, 1) # b,hw,c - del q2 - - k = k1.reshape(b, c, h*w) # b,c,hw - del k1 - - h_ = torch.zeros_like(k, device=q.device) - - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - - tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() - mem_required = tensor_size * 2.5 - steps = 1 - - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - - w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w2 = w1 * (int(c)**(-0.5)) - del w1 - w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) - del w2 - - # attend to values - v1 = v.reshape(b, c, h*w) - w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - del w3 - - h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - del v1, w4 - - h2 = h_.reshape(b, c, h, w) - del h_ - - h3 = self.proj_out(h2) - del h2 - - h3 += x - - return h3 class StableDiffusionModelHijack: - ids_lookup = {} - word_embeddings = {} - word_embeddings_checksums = {} fixes = None comments = [] - dir_mtime = None layers = None circular_enabled = False clip = None - def load_textual_inversion_embeddings(self, dirname, model): - mt = os.path.getmtime(dirname) - if self.dir_mtime is not None and mt <= self.dir_mtime: - return - - self.dir_mtime = mt - self.ids_lookup.clear() - self.word_embeddings.clear() - - tokenizer = model.cond_stage_model.tokenizer - - def const_hash(a): - r = 0 - for v in a: - r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF - return r - - def process_file(path, filename): - name = os.path.splitext(filename)[0] - - data = torch.load(path, map_location="cpu") - - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - if hasattr(param_dict, '_parameters'): - param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - # diffuser concepts - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - - self.word_embeddings[name] = emb.detach().to(device) - self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}' - - ids = tokenizer([name], add_special_tokens=False)['input_ids'][0] - - first_id = ids[0] - if first_id not in self.ids_lookup: - self.ids_lookup[first_id] = [] - self.ids_lookup[first_id].append((ids, name)) - - for fn in os.listdir(dirname): - try: - fullfn = os.path.join(dirname, fn) - - if os.stat(fullfn).st_size == 0: - continue - - process_file(fullfn, fn) - except Exception: - print(f"Error loading emedding {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - continue - - print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") + embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) def hijack(self, m): model_embeddings = m.cond_stage_model.transformer.text_model.embeddings @@ -253,12 +50,7 @@ class StableDiffusionModelHijack: self.clip = m.cond_stage_model - if cmd_opts.opt_split_attention_v1: - ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): - ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward - ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack - ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward + apply_optimizations() def flatten(el): flattened = [flatten(children) for children in el.children()] @@ -296,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): super().__init__() self.wrapped = wrapped - self.hijack = hijack + self.hijack: StableDiffusionModelHijack = hijack self.tokenizer = wrapped.tokenizer self.max_length = wrapped.max_length self.token_mults = {} @@ -317,7 +109,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult - def tokenize_line(self, line, used_custom_terms, hijack_comments): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id @@ -339,28 +130,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - possible_matches = self.hijack.ids_lookup.get(token, None) + embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) - if possible_matches is None: + if embedding is None: remade_tokens.append(token) multipliers.append(weight) + i += 1 else: - found = False - for ids, word in possible_matches: - if tokens[i:i + len(ids)] == ids: - emb_len = int(self.hijack.word_embeddings[word].shape[0]) - fixes.append((len(remade_tokens), word)) - remade_tokens += [0] * emb_len - multipliers += [weight] * emb_len - i += len(ids) - 1 - found = True - used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word])) - break - - if not found: - remade_tokens.append(token) - multipliers.append(weight) - i += 1 + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [weight] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += emb_len if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} @@ -431,32 +213,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - possible_matches = self.hijack.ids_lookup.get(token, None) + embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) mult_change = self.token_mults.get(token) if opts.enable_emphasis else None if mult_change is not None: mult *= mult_change - elif possible_matches is None: + i += 1 + elif embedding is None: remade_tokens.append(token) multipliers.append(mult) + i += 1 else: - found = False - for ids, word in possible_matches: - if tokens[i:i+len(ids)] == ids: - emb_len = int(self.hijack.word_embeddings[word].shape[0]) - fixes.append((len(remade_tokens), word)) - remade_tokens += [0] * emb_len - multipliers += [mult] * emb_len - i += len(ids) - 1 - found = True - used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word])) - break - - if not found: - remade_tokens.append(token) - multipliers.append(mult) - - i += 1 + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [mult] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += emb_len if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} @@ -464,6 +237,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): overflowing_words = [vocab.get(int(x), "") for x in ovf] overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + token_count = len(remade_tokens) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] @@ -484,7 +258,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): else: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - self.hijack.fixes = hijack_fixes self.hijack.comments = hijack_comments @@ -517,14 +290,19 @@ class EmbeddingsWithFixes(torch.nn.Module): inputs_embeds = self.wrapped(input_ids) - if batch_fixes is not None: - for fixes, tensor in zip(batch_fixes, inputs_embeds): - for offset, word in fixes: - emb = self.embeddings.word_embeddings[word] - emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) - tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len] + if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0: + return inputs_embeds + + vecs = [] + for fixes, tensor in zip(batch_fixes, inputs_embeds): + for offset, embedding in fixes: + emb = embedding.vec + emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) + tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]]) + + vecs.append(tensor) - return inputs_embeds + return torch.stack(vecs) def add_circular_option_to_conv_2d(): -- cgit v1.2.3 From 88ec0cf5571883d84abd09196652b3679e359f2e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 19:40:51 +0300 Subject: fix for incorrect embedding token length calculation (will break seeds that use embeddings, you're welcome!) add option to input initialization text for embeddings --- modules/sd_hijack.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index fd57e5c5..3fa06242 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -130,7 +130,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) if embedding is None: remade_tokens.append(token) @@ -142,7 +142,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_tokens += [0] * emb_len multipliers += [weight] * emb_len used_custom_terms.append((embedding.name, embedding.checksum())) - i += emb_len + i += embedding_length_in_tokens if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} @@ -213,7 +213,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) mult_change = self.token_mults.get(token) if opts.enable_emphasis else None if mult_change is not None: @@ -229,7 +229,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_tokens += [0] * emb_len multipliers += [mult] * emb_len used_custom_terms.append((embedding.name, embedding.checksum())) - i += emb_len + i += embedding_length_in_tokens if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} -- cgit v1.2.3 From 2eb911b056ce6ff4434f673366782ed34f2b2f12 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Fri, 7 Oct 2022 05:22:28 +0300 Subject: Update sd_hijack.py --- modules/sd_hijack.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index a6fa890c..6221ed5a 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -20,12 +20,17 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations(): - ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward + if cmd_opts.opt_split_attention: + ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward + ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward + elif not cmd_opts.disable_opt_xformers_attention: + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward + ldm.modules.attention.CrossAttention._maybe_init = sd_hijack_optimizations._maybe_init + ldm.modules.attention.CrossAttention.attention_op = None + ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward -- cgit v1.2.3 From 35d6b231628d18d53d166c3a92fea1523e88d51e Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Fri, 7 Oct 2022 05:31:53 +0300 Subject: Update sd_hijack.py --- modules/sd_hijack.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 6221ed5a..a006c0a3 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -20,17 +20,16 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations(): + ldm.modules.diffusionmodules.model.nonlinearity = silu if cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 if cmd_opts.opt_split_attention: ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward - ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward elif not cmd_opts.disable_opt_xformers_attention: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.attention.CrossAttention._maybe_init = sd_hijack_optimizations._maybe_init ldm.modules.attention.CrossAttention.attention_op = None - ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward -- cgit v1.2.3 From 5303df24282ba06abb34a423f2967354d37d078e Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Fri, 7 Oct 2022 06:01:14 +0300 Subject: Update sd_hijack.py --- modules/sd_hijack.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index a006c0a3..ddacb0ad 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -23,10 +23,10 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu if cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - if cmd_opts.opt_split_attention: + elif cmd_opts.opt_split_attention: ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward - elif not cmd_opts.disable_opt_xformers_attention: + elif not cmd_opts.disable_opt_xformers_attention and not cmd_opts.opt_split_attention: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.attention.CrossAttention._maybe_init = sd_hijack_optimizations._maybe_init ldm.modules.attention.CrossAttention.attention_op = None -- cgit v1.2.3 From 5e3ff846c56dc8e1d5c76ea04a8f2f74d7da07fc Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Fri, 7 Oct 2022 06:38:01 +0300 Subject: Update sd_hijack.py --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index ddacb0ad..cbdb9d3c 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -26,7 +26,7 @@ def apply_optimizations(): elif cmd_opts.opt_split_attention: ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward - elif not cmd_opts.disable_opt_xformers_attention and not cmd_opts.opt_split_attention: + elif not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip): ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.attention.CrossAttention._maybe_init = sd_hijack_optimizations._maybe_init ldm.modules.attention.CrossAttention.attention_op = None -- cgit v1.2.3 From f7c787eb7c295c27439f4fbdf78c26b8389560be Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 7 Oct 2022 16:39:51 +0300 Subject: make it possible to use hypernetworks without opt split attention --- modules/sd_hijack.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index a6fa890c..d68f89cc 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -8,7 +8,7 @@ from torch import einsum from torch.nn.functional import silu import modules.textual_inversion.textual_inversion -from modules import prompt_parser, devices, sd_hijack_optimizations, shared +from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork from modules.shared import opts, device, cmd_opts import ldm.modules.attention @@ -20,6 +20,8 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations(): + undo_optimizations() + ldm.modules.diffusionmodules.model.nonlinearity = silu if cmd_opts.opt_split_attention_v1: @@ -30,7 +32,7 @@ def apply_optimizations(): def undo_optimizations(): - ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward + ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward -- cgit v1.2.3 From 12c4d5c6b5bf9dd50d0601c36af4f99b65316d58 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 7 Oct 2022 23:22:22 +0300 Subject: hypernetwork training mk1 --- modules/sd_hijack.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index d68f89cc..ec8c9d4b 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -8,7 +8,7 @@ from torch import einsum from torch.nn.functional import silu import modules.textual_inversion.textual_inversion -from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork +from modules import prompt_parser, devices, sd_hijack_optimizations, shared from modules.shared import opts, device, cmd_opts import ldm.modules.attention @@ -32,6 +32,8 @@ def apply_optimizations(): def undo_optimizations(): + from modules.hypernetwork import hypernetwork + ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward -- cgit v1.2.3 From b70eaeb2005a5a9593119e7fd32b8072c2a208d5 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 04:10:35 +0300 Subject: delete broken and unnecessary aliases --- modules/sd_hijack.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index cbdb9d3c..0e99c319 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -21,16 +21,14 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.opt_split_attention_v1: + if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip): + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward + elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif cmd_opts.opt_split_attention: ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward - elif not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip): - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.attention.CrossAttention._maybe_init = sd_hijack_optimizations._maybe_init - ldm.modules.attention.CrossAttention.attention_op = None - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward def undo_optimizations(): -- cgit v1.2.3 From 91d66f5520df416db718103d460550ad495e952d Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 11:56:01 +0300 Subject: use new attnblock for xformers path --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 0e99c319..3da8c8ce 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -23,7 +23,7 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip): ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif cmd_opts.opt_split_attention: -- cgit v1.2.3 From 706d5944a075a6523ea7f00165d630efc085ca22 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 13:38:57 +0300 Subject: let user choose his own prompt token count limit --- modules/sd_hijack.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index d68f89cc..340329c0 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -18,7 +18,6 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward - def apply_optimizations(): undo_optimizations() @@ -83,7 +82,7 @@ class StableDiffusionModelHijack: layer.padding_mode = 'circular' if enable else 'zeros' def tokenize(self, text): - max_length = self.clip.max_length - 2 + max_length = opts.max_prompt_tokens - 2 _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, max_length @@ -94,7 +93,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.wrapped = wrapped self.hijack: StableDiffusionModelHijack = hijack self.tokenizer = wrapped.tokenizer - self.max_length = wrapped.max_length self.token_mults = {} tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] @@ -116,7 +114,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def tokenize_line(self, line, used_custom_terms, hijack_comments): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id - maxlen = self.wrapped.max_length + maxlen = opts.max_prompt_tokens if opts.enable_emphasis: parsed = prompt_parser.parse_prompt_attention(line) @@ -191,7 +189,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def process_text_old(self, text): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id - maxlen = self.wrapped.max_length + maxlen = self.wrapped.max_length # you get to stay at 77 used_custom_terms = [] remade_batch_tokens = [] overflowing_words = [] @@ -268,8 +266,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if len(used_custom_terms) > 0: self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + position_ids_array = [min(x, 75) for x in range(len(remade_batch_tokens[0])-1)] + [76] + position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1)) + tokens = torch.asarray(remade_batch_tokens).to(device) - outputs = self.wrapped.transformer(input_ids=tokens) + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) z = outputs.last_hidden_state # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise -- cgit v1.2.3 From 4999eb2ef9b30e8c42ca7e4a94d4bbffe4d1f015 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 14:25:47 +0300 Subject: do not let user choose his own prompt token count limit --- modules/sd_hijack.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 340329c0..2c1332c9 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -36,6 +36,13 @@ def undo_optimizations(): ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward +def get_target_prompt_token_count(token_count): + if token_count < 75: + return 75 + + return math.ceil(token_count / 10) * 10 + + class StableDiffusionModelHijack: fixes = None comments = [] @@ -84,7 +91,7 @@ class StableDiffusionModelHijack: def tokenize(self, text): max_length = opts.max_prompt_tokens - 2 _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) - return remade_batch_tokens[0], token_count, max_length + return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): @@ -114,7 +121,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def tokenize_line(self, line, used_custom_terms, hijack_comments): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id - maxlen = opts.max_prompt_tokens if opts.enable_emphasis: parsed = prompt_parser.parse_prompt_attention(line) @@ -146,19 +152,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): used_custom_terms.append((embedding.name, embedding.checksum())) i += embedding_length_in_tokens - if len(remade_tokens) > maxlen - 2: - vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} - ovf = remade_tokens[maxlen - 2:] - overflowing_words = [vocab.get(int(x), "") for x in ovf] - overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") - token_count = len(remade_tokens) - remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) - remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] + prompt_target_length = get_target_prompt_token_count(token_count) + tokens_to_add = prompt_target_length - len(remade_tokens) + 1 - multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) - multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] + remade_tokens = [id_start] + remade_tokens + [id_end] * tokens_to_add + multipliers = [1.0] + multipliers + [1.0] * tokens_to_add return remade_tokens, fixes, multipliers, token_count -- cgit v1.2.3 From 77f4237d1c3af1756e7dab2699e3dcebad5619d6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 15:25:59 +0300 Subject: fix bugs related to variable prompt lengths --- modules/sd_hijack.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 2c1332c9..7e7fde0f 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -89,7 +89,6 @@ class StableDiffusionModelHijack: layer.padding_mode = 'circular' if enable else 'zeros' def tokenize(self, text): - max_length = opts.max_prompt_tokens - 2 _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) @@ -174,7 +173,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if line in cache: remade_tokens, fixes, multipliers = cache[line] else: - remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) + remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) + token_count = max(current_token_count, token_count) cache[line] = (remade_tokens, fixes, multipliers) @@ -265,15 +265,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if len(used_custom_terms) > 0: self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) - position_ids_array = [min(x, 75) for x in range(len(remade_batch_tokens[0])-1)] + [76] + target_token_count = get_target_prompt_token_count(token_count) + 2 + + position_ids_array = [min(x, 75) for x in range(target_token_count-1)] + [76] position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1)) - tokens = torch.asarray(remade_batch_tokens).to(device) + remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] + tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) z = outputs.last_hidden_state # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise - batch_multipliers = torch.asarray(batch_multipliers).to(device) + batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers] + batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device) original_mean = z.mean() z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) new_mean = z.mean() -- cgit v1.2.3 From 5f85a74b00c0154bfd559dc67edfa7e30342b7c9 Mon Sep 17 00:00:00 2001 From: MrCheeze Date: Fri, 7 Oct 2022 17:48:34 -0400 Subject: fix bug where when using prompt composition, hijack_comments generated before the final AND will be dropped --- modules/sd_hijack.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 7e7fde0f..ba808a39 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -88,6 +88,9 @@ class StableDiffusionModelHijack: for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]: layer.padding_mode = 'circular' if enable else 'zeros' + def clear_comments(self): + self.comments = [] + def tokenize(self, text): _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) @@ -260,7 +263,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) self.hijack.fixes = hijack_fixes - self.hijack.comments = hijack_comments + self.hijack.comments += hijack_comments if len(used_custom_terms) > 0: self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) -- cgit v1.2.3 From 26b459a3799c5cdf71ca8ed5315a99f69c69f02c Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 16:20:04 +0300 Subject: default to split attention if cuda is available and xformers is not --- modules/sd_hijack.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3da8c8ce..04adcf03 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -21,12 +21,12 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu - if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip): + if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip or shared.xformers_available): ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - elif cmd_opts.opt_split_attention: + elif cmd_opts.opt_split_attention or torch.cuda.is_available(): ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward -- cgit v1.2.3 From 970de9ee6891ff586821d0d80dde01c2f6c681b3 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 16:29:43 +0300 Subject: Update sd_hijack.py --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 04adcf03..5b30539f 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -21,7 +21,7 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu - if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip or shared.xformers_available): + if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip) and shared.xformers_available: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: -- cgit v1.2.3 From dc1117233ef8f9b25ff1ac40b158f20b70ba2fcb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 17:02:18 +0300 Subject: simplify xfrmers options: --xformers to enable and that's it --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 5d93f7f6..91e98c16 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -22,7 +22,7 @@ def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip) and shared.xformers_available: + if cmd_opts.xformers and shared.xformers_available and not torch.version.hip: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: -- cgit v1.2.3 From 27032c47df9c07ac21dd5b89fa7dc247bb8705b6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 17:10:05 +0300 Subject: restore old opt_split_attention/disable_opt_split_attention logic --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 91e98c16..335a2bcf 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -27,7 +27,7 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - elif cmd_opts.opt_split_attention or torch.cuda.is_available(): + elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward -- cgit v1.2.3 From cfc33f99d47d1f45af15499e5965834089d11858 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 17:28:58 +0300 Subject: why did you do this --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 335a2bcf..ed271976 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -28,7 +28,7 @@ def apply_optimizations(): elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): - ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward -- cgit v1.2.3 From 017b6b8744f0771e498656ec043e12d5cc6969a7 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 17:27:21 +0300 Subject: check for ampere --- modules/sd_hijack.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index ed271976..5e266d5e 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -22,9 +22,10 @@ def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.xformers and shared.xformers_available and not torch.version.hip: - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward + if cmd_opts.xformers and shared.xformers_available and torch.version.cuda: + if torch.cuda.get_device_capability(shared.device) == (8, 6): + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): -- cgit v1.2.3 From cc0258aea7b6605be3648900063cfa96ed7c5ffa Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 17:44:53 +0300 Subject: check for ampere without destroying the optimizations. again. --- modules/sd_hijack.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 5e266d5e..a3e374f0 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -22,10 +22,9 @@ def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.xformers and shared.xformers_available and torch.version.cuda: - if torch.cuda.get_device_capability(shared.device) == (8, 6): - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward + if cmd_opts.xformers and shared.xformers_available and torch.version.cuda and torch.cuda.get_device_capability(shared.device) == (8, 6): + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): -- cgit v1.2.3 From 3061cdb7b610d4ba7f1ea695d9d6364b591e5bc7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 19:22:15 +0300 Subject: add --force-enable-xformers option and also add messages to console regarding cross attention optimizations --- modules/sd_hijack.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index a3e374f0..307cc67d 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -22,12 +22,16 @@ def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.xformers and shared.xformers_available and torch.version.cuda and torch.cuda.get_device_capability(shared.device) == (8, 6): + + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and torch.cuda.get_device_capability(shared.device) == (8, 6)): + print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: + print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): + print("Applying cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward -- cgit v1.2.3 From 1371d7608b402d6f15c200ec2f5fde4579836a05 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sat, 8 Oct 2022 14:28:22 -0400 Subject: Added ability to ignore last n layers in FrozenCLIPEmbedder --- modules/sd_hijack.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 307cc67d..f12a9696 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -281,8 +281,15 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) - outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) - z = outputs.last_hidden_state + + tmp = -opts.CLIP_ignore_last_layers + if (opts.CLIP_ignore_last_layers == 0): + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) + z = outputs.last_hidden_state + else: + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp) + z = outputs.hidden_states[tmp] + z = self.wrapped.transformer.text_model.final_layer_norm(z) # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers] -- cgit v1.2.3 From e59c66c0088422b27f64b401ef42c242f836725a Mon Sep 17 00:00:00 2001 From: Fampai Date: Sat, 8 Oct 2022 16:32:05 -0400 Subject: Optimized code for Ignoring last CLIP layers --- modules/sd_hijack.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f12a9696..4a2d2153 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -282,14 +282,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) - tmp = -opts.CLIP_ignore_last_layers - if (opts.CLIP_ignore_last_layers == 0): - outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) - z = outputs.last_hidden_state - else: - outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp) - z = outputs.hidden_states[tmp] - z = self.wrapped.transformer.text_model.final_layer_norm(z) + tmp = -opts.CLIP_stop_at_last_layers + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp) + z = outputs.hidden_states[tmp] + z = self.wrapped.transformer.text_model.final_layer_norm(z) # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers] -- cgit v1.2.3 From ad3ae441081155dcd4fde805279e5082ca264695 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sun, 9 Oct 2022 04:32:40 -0400 Subject: Updated code for legibility --- modules/sd_hijack.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 4a2d2153..7793d25b 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -284,8 +284,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): tmp = -opts.CLIP_stop_at_last_layers outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp) - z = outputs.hidden_states[tmp] - z = self.wrapped.transformer.text_model.final_layer_norm(z) + if tmp < -1: + z = outputs.hidden_states[tmp] + z = self.wrapped.transformer.text_model.final_layer_norm(z) + else: + z = outputs.last_hidden_state # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers] -- cgit v1.2.3 From 1824e9ee3ab4f94aee8908a62ea2569a01aeb3d7 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sun, 9 Oct 2022 14:15:43 -0400 Subject: Removed unnecessary tmp variable --- modules/sd_hijack.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 7793d25b..437acce4 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -282,10 +282,9 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) - tmp = -opts.CLIP_stop_at_last_layers - outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp) - if tmp < -1: - z = outputs.hidden_states[tmp] + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=-opts.CLIP_stop_at_last_layers) + if opts.CLIP_stop_at_last_layers > 1: + z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] z = self.wrapped.transformer.text_model.final_layer_norm(z) else: z = outputs.last_hidden_state -- cgit v1.2.3 From b340439586d844e76782149ca1857c8de35773ec Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Mon, 10 Oct 2022 05:28:06 +0100 Subject: Unlimited Token Works Unlimited tokens actually work now. Works with textual inversion too. Replaces the previous not-so-much-working implementation. --- modules/sd_hijack.py | 69 ++++++++++++++++++++++++++++++++++------------------ 1 file changed, 46 insertions(+), 23 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 437acce4..8d5c77d8 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -43,10 +43,7 @@ def undo_optimizations(): def get_target_prompt_token_count(token_count): - if token_count < 75: - return 75 - - return math.ceil(token_count / 10) * 10 + return math.ceil(max(token_count, 1) / 75) * 75 class StableDiffusionModelHijack: @@ -127,7 +124,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.token_mults[ident] = mult def tokenize_line(self, line, used_custom_terms, hijack_comments): - id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id if opts.enable_emphasis: @@ -154,7 +150,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): i += 1 else: emb_len = int(embedding.vec.shape[0]) - fixes.append((len(remade_tokens), embedding)) + iteration = len(remade_tokens) // 75 + fixes.append((iteration, (len(remade_tokens) % 75, embedding))) remade_tokens += [0] * emb_len multipliers += [weight] * emb_len used_custom_terms.append((embedding.name, embedding.checksum())) @@ -162,10 +159,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): token_count = len(remade_tokens) prompt_target_length = get_target_prompt_token_count(token_count) - tokens_to_add = prompt_target_length - len(remade_tokens) + 1 + tokens_to_add = prompt_target_length - len(remade_tokens) - remade_tokens = [id_start] + remade_tokens + [id_end] * tokens_to_add - multipliers = [1.0] + multipliers + [1.0] * tokens_to_add + remade_tokens = remade_tokens + [id_end] * tokens_to_add + multipliers = multipliers + [1.0] * tokens_to_add return remade_tokens, fixes, multipliers, token_count @@ -260,29 +257,55 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): hijack_fixes.append(fixes) batch_multipliers.append(multipliers) return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - + def forward(self, text): - - if opts.use_old_emphasis_implementation: + use_old = opts.use_old_emphasis_implementation + if use_old: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) else: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - self.hijack.fixes = hijack_fixes self.hijack.comments += hijack_comments if len(used_custom_terms) > 0: self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + + if use_old: + self.hijack.fixes = hijack_fixes + return self.process_tokens(remade_batch_tokens, batch_multipliers) + + z = None + i = 0 + while max(map(len, remade_batch_tokens)) != 0: + rem_tokens = [x[75:] for x in remade_batch_tokens] + rem_multipliers = [x[75:] for x in batch_multipliers] + + self.hijack.fixes = [] + for unfiltered in hijack_fixes: + fixes = [] + for fix in unfiltered: + if fix[0] == i: + fixes.append(fix[1]) + self.hijack.fixes.append(fixes) + + z1 = self.process_tokens([x[:75] for x in remade_batch_tokens], [x[:75] for x in batch_multipliers]) + z = z1 if z is None else torch.cat((z, z1), axis=-2) + + remade_batch_tokens = rem_tokens + batch_multipliers = rem_multipliers + i += 1 + + return z + + + def process_tokens(self, remade_batch_tokens, batch_multipliers): + if not opts.use_old_emphasis_implementation: + remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens] + batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] + + tokens = torch.asarray(remade_batch_tokens).to(device) + outputs = self.wrapped.transformer(input_ids=tokens) - target_token_count = get_target_prompt_token_count(token_count) + 2 - - position_ids_array = [min(x, 75) for x in range(target_token_count-1)] + [76] - position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1)) - - remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] - tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) - - outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=-opts.CLIP_stop_at_last_layers) if opts.CLIP_stop_at_last_layers > 1: z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] z = self.wrapped.transformer.text_model.final_layer_norm(z) @@ -290,7 +313,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z = outputs.last_hidden_state # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise - batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers] + batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers] batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device) original_mean = z.mean() z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) -- cgit v1.2.3 From 460bbae58726c177beddfcddf351f27e205d3fb2 Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Mon, 10 Oct 2022 16:09:06 +0100 Subject: Pad beginning of textual inversion embedding --- modules/sd_hijack.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 8d5c77d8..3a60cd63 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -151,6 +151,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): else: emb_len = int(embedding.vec.shape[0]) iteration = len(remade_tokens) // 75 + if (len(remade_tokens) + emb_len) // 75 != iteration: + rem = (75 * (iteration + 1) - len(remade_tokens)) + remade_tokens += [id_end] * rem + multipliers += [1.0] * rem + iteration += 1 fixes.append((iteration, (len(remade_tokens) % 75, embedding))) remade_tokens += [0] * emb_len multipliers += [weight] * emb_len -- cgit v1.2.3 From d5c14365fd468dbf89fa12a68bea5b217077273c Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Mon, 10 Oct 2022 16:13:47 +0100 Subject: Add back in output hidden states parameter --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3a60cd63..3edc0e9d 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -309,7 +309,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] tokens = torch.asarray(remade_batch_tokens).to(device) - outputs = self.wrapped.transformer(input_ids=tokens) + outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) if opts.CLIP_stop_at_last_layers > 1: z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] -- cgit v1.2.3 From 623251ce2b8d152e242011f62984a8247a14a389 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Mon, 10 Oct 2022 17:45:38 +0300 Subject: allow pascal onwards --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3edc0e9d..827bf304 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -23,7 +23,7 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and torch.cuda.get_device_capability(shared.device) == (8, 6)): + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward -- cgit v1.2.3 From 5e2627a1a63e4c9f87e6e604ecc24e9936f149de Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Tue, 11 Oct 2022 07:55:28 +0100 Subject: Comma backtrack padding (#2192) Comma backtrack padding --- modules/sd_hijack.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 827bf304..aa4d2cbc 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -107,6 +107,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.tokenizer = wrapped.tokenizer self.token_mults = {} + self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] + tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] for text, ident in tokens_with_parens: mult = 1.0 @@ -136,6 +138,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): fixes = [] remade_tokens = [] multipliers = [] + last_comma = -1 for tokens, (text, weight) in zip(tokenized, parsed): i = 0 @@ -144,6 +147,20 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + if token == self.comma_token: + last_comma = len(remade_tokens) + elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack: + last_comma += 1 + reloc_tokens = remade_tokens[last_comma:] + reloc_mults = multipliers[last_comma:] + + remade_tokens = remade_tokens[:last_comma] + length = len(remade_tokens) + + rem = int(math.ceil(length / 75)) * 75 - length + remade_tokens += [id_end] * rem + reloc_tokens + multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults + if embedding is None: remade_tokens.append(token) multipliers.append(weight) @@ -284,7 +301,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while max(map(len, remade_batch_tokens)) != 0: rem_tokens = [x[75:] for x in remade_batch_tokens] rem_multipliers = [x[75:] for x in batch_multipliers] - + self.hijack.fixes = [] for unfiltered in hijack_fixes: fixes = [] -- cgit v1.2.3 From 873efeed49bb5197a42da18272115b326c5d68f3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 15:51:22 +0300 Subject: rename hypernetwork dir to hypernetworks to prevent clash with an old filename that people who use zip instead of git clone will have --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f873049a..f07ec041 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -37,7 +37,7 @@ def apply_optimizations(): def undo_optimizations(): - from modules.hypernetwork import hypernetwork + from modules.hypernetworks import hypernetwork ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity -- cgit v1.2.3 From c0484f1b986ce7acb0e3596f6089a191279f5442 Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 10 Oct 2022 22:48:54 -0400 Subject: Add cross-attention optimization from InvokeAI * Add cross-attention optimization from InvokeAI (~30% speed improvement on MPS) * Add command line option for it * Make it default when CUDA is unavailable --- modules/sd_hijack.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f07ec041..5a1b167f 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -30,8 +30,11 @@ def apply_optimizations(): elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 + elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): + print("Applying cross attention optimization (InvokeAI).") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): - print("Applying cross attention optimization.") + print("Applying cross attention optimization (Doggettx).") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward -- cgit v1.2.3 From 98fd5cde72d5bda1620ab78416c7828fdc3dc10b Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 10 Oct 2022 23:55:48 -0400 Subject: Add check for psutil --- modules/sd_hijack.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 5a1b167f..ac70f876 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -10,6 +10,7 @@ from torch.nn.functional import silu import modules.textual_inversion.textual_inversion from modules import prompt_parser, devices, sd_hijack_optimizations, shared from modules.shared import opts, device, cmd_opts +from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention import ldm.modules.diffusionmodules.model @@ -31,8 +32,13 @@ def apply_optimizations(): print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): - print("Applying cross attention optimization (InvokeAI).") - ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI + if not invokeAI_mps_available and shared.device.type == 'mps': + print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.") + print("Applying v1 cross attention optimization.") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 + else: + print("Applying cross attention optimization (InvokeAI).") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): print("Applying cross attention optimization (Doggettx).") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward -- cgit v1.2.3 From 80f3cf2bb2ce3f00d801cae2c3a8c20a8d4167d8 Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Tue, 11 Oct 2022 19:48:53 +0100 Subject: Account when lines are mismatched --- modules/sd_hijack.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index ac70f876..2753d4fa 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -321,7 +321,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): fixes.append(fix[1]) self.hijack.fixes.append(fixes) - z1 = self.process_tokens([x[:75] for x in remade_batch_tokens], [x[:75] for x in batch_multipliers]) + tokens = [] + multipliers = [] + for i in range(len(remade_batch_tokens)): + if len(remade_batch_tokens[i]) > 0: + tokens.append(remade_batch_tokens[i][:75]) + multipliers.append(batch_multipliers[i][:75]) + else: + tokens.append([self.wrapped.tokenizer.eos_token_id] * 75) + multipliers.append([1.0] * 75) + + z1 = self.process_tokens(tokens, multipliers) z = z1 if z is None else torch.cat((z, z1), axis=-2) remade_batch_tokens = rem_tokens -- cgit v1.2.3 From 429442f4a6aab7301efb89d27bef524fe827e81a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 12 Oct 2022 13:38:03 +0300 Subject: fix iterator bug for #2295 --- modules/sd_hijack.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 2753d4fa..c81722a0 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -323,10 +323,10 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): tokens = [] multipliers = [] - for i in range(len(remade_batch_tokens)): - if len(remade_batch_tokens[i]) > 0: - tokens.append(remade_batch_tokens[i][:75]) - multipliers.append(batch_multipliers[i][:75]) + for j in range(len(remade_batch_tokens)): + if len(remade_batch_tokens[j]) > 0: + tokens.append(remade_batch_tokens[j][:75]) + multipliers.append(batch_multipliers[j][:75]) else: tokens.append([self.wrapped.tokenizer.eos_token_id] * 75) multipliers.append([1.0] * 75) -- cgit v1.2.3 From bb57f30c2de46cfca5419ad01738a41705f96cc3 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Fri, 14 Oct 2022 10:56:41 +0200 Subject: init --- modules/sd_hijack.py | 80 ++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 78 insertions(+), 2 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index c81722a0..6d5196fe 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -9,11 +9,14 @@ from torch.nn.functional import silu import modules.textual_inversion.textual_inversion from modules import prompt_parser, devices, sd_hijack_optimizations, shared -from modules.shared import opts, device, cmd_opts +from modules.shared import opts, device, cmd_opts, aesthetic_embeddings from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention import ldm.modules.diffusionmodules.model +from transformers import CLIPVisionModel, CLIPModel +import torch.optim as optim +import copy attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity @@ -109,13 +112,29 @@ class StableDiffusionModelHijack: _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) +def slerp(low, high, val): + low_norm = low/torch.norm(low, dim=1, keepdim=True) + high_norm = high/torch.norm(high, dim=1, keepdim=True) + omega = torch.acos((low_norm*high_norm).sum(1)) + so = torch.sin(omega) + res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + return res class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): super().__init__() self.wrapped = wrapped + self.clipModel = CLIPModel.from_pretrained( + self.wrapped.transformer.name_or_path + ) + del self.clipModel.vision_model self.hijack: StableDiffusionModelHijack = hijack self.tokenizer = wrapped.tokenizer + # self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval() + self.image_embs_name = None + self.image_embs = None + self.load_image_embs(None) + self.token_mults = {} self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] @@ -136,6 +155,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult + def set_aesthetic_params(self, aesthetic_lr, aesthetic_weight, aesthetic_steps, image_embs_name=None, + aesthetic_slerp=True): + self.slerp = aesthetic_slerp + self.aesthetic_lr = aesthetic_lr + self.aesthetic_weight = aesthetic_weight + self.aesthetic_steps = aesthetic_steps + self.load_image_embs(image_embs_name) + + def load_image_embs(self, image_embs_name): + if image_embs_name is None or len(image_embs_name) == 0: + image_embs_name = None + if image_embs_name is not None and self.image_embs_name != image_embs_name: + self.image_embs_name = image_embs_name + self.image_embs = torch.load(aesthetic_embeddings[self.image_embs_name], map_location=device) + self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True) + self.image_embs.requires_grad_(False) + def tokenize_line(self, line, used_custom_terms, hijack_comments): id_end = self.wrapped.tokenizer.eos_token_id @@ -333,7 +369,47 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z1 = self.process_tokens(tokens, multipliers) z = z1 if z is None else torch.cat((z, z1), axis=-2) - + + if len(text[ + 0]) != 0 and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None: + if not opts.use_old_emphasis_implementation: + remade_batch_tokens = [ + [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in + remade_batch_tokens] + + tokens = torch.asarray(remade_batch_tokens).to(device) + with torch.enable_grad(): + model = copy.deepcopy(self.clipModel).to(device) + model.requires_grad_(True) + + # We optimize the model to maximize the similarity + optimizer = optim.Adam( + model.text_model.parameters(), lr=self.aesthetic_lr + ) + + for i in range(self.aesthetic_steps): + text_embs = model.get_text_features(input_ids=tokens) + text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True) + sim = text_embs @ self.image_embs.T + loss = -sim + optimizer.zero_grad() + loss.mean().backward() + optimizer.step() + + zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) + if opts.CLIP_stop_at_last_layers > 1: + zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers] + zn = model.text_model.final_layer_norm(zn) + else: + zn = zn.last_hidden_state + model.cpu() + del model + + if self.slerp: + z = slerp(z, zn, self.aesthetic_weight) + else: + z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight + remade_batch_tokens = rem_tokens batch_multipliers = rem_multipliers i += 1 -- cgit v1.2.3 From 37d7ffb415cd8c69b3c0bb5f61844dde0b169f78 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sat, 15 Oct 2022 15:59:37 +0200 Subject: fix to tokens lenght, addend embs generator, add new features to edit the embedding before the generation using text --- modules/sd_hijack.py | 111 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 73 insertions(+), 38 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 6d5196fe..192883b2 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -14,7 +14,8 @@ from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention import ldm.modules.diffusionmodules.model -from transformers import CLIPVisionModel, CLIPModel +from tqdm import trange +from transformers import CLIPVisionModel, CLIPModel, CLIPTokenizer import torch.optim as optim import copy @@ -22,21 +23,25 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward + def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)): + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and ( + 6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): + elif not cmd_opts.disable_opt_split_attention and ( + cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): if not invokeAI_mps_available and shared.device.type == 'mps': - print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.") + print( + "The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.") print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 else: @@ -112,14 +117,16 @@ class StableDiffusionModelHijack: _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) + def slerp(low, high, val): - low_norm = low/torch.norm(low, dim=1, keepdim=True) - high_norm = high/torch.norm(high, dim=1, keepdim=True) - omega = torch.acos((low_norm*high_norm).sum(1)) + low_norm = low / torch.norm(low, dim=1, keepdim=True) + high_norm = high / torch.norm(high, dim=1, keepdim=True) + omega = torch.acos((low_norm * high_norm).sum(1)) so = torch.sin(omega) - res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high return res + class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): super().__init__() @@ -128,6 +135,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.wrapped.transformer.name_or_path ) del self.clipModel.vision_model + self.tokenizer = CLIPTokenizer.from_pretrained(self.wrapped.transformer.name_or_path) self.hijack: StableDiffusionModelHijack = hijack self.tokenizer = wrapped.tokenizer # self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval() @@ -139,7 +147,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] - tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] + tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if + '(' in k or ')' in k or '[' in k or ']' in k] for text, ident in tokens_with_parens: mult = 1.0 for c in text: @@ -155,8 +164,13 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult - def set_aesthetic_params(self, aesthetic_lr, aesthetic_weight, aesthetic_steps, image_embs_name=None, - aesthetic_slerp=True): + def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, + aesthetic_slerp=True, aesthetic_imgs_text="", + aesthetic_slerp_angle=0.15, + aesthetic_text_negative=False): + self.aesthetic_imgs_text = aesthetic_imgs_text + self.aesthetic_slerp_angle = aesthetic_slerp_angle + self.aesthetic_text_negative = aesthetic_text_negative self.slerp = aesthetic_slerp self.aesthetic_lr = aesthetic_lr self.aesthetic_weight = aesthetic_weight @@ -180,7 +194,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): else: parsed = [[line, 1.0]] - tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"] + tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)[ + "input_ids"] fixes = [] remade_tokens = [] @@ -196,18 +211,20 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if token == self.comma_token: last_comma = len(remade_tokens) - elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack: + elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), + 1) % 75 == 0 and last_comma != -1 and len( + remade_tokens) - last_comma <= opts.comma_padding_backtrack: last_comma += 1 reloc_tokens = remade_tokens[last_comma:] reloc_mults = multipliers[last_comma:] remade_tokens = remade_tokens[:last_comma] length = len(remade_tokens) - + rem = int(math.ceil(length / 75)) * 75 - length remade_tokens += [id_end] * rem + reloc_tokens multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults - + if embedding is None: remade_tokens.append(token) multipliers.append(weight) @@ -248,7 +265,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if line in cache: remade_tokens, fixes, multipliers = cache[line] else: - remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) + remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, + hijack_comments) token_count = max(current_token_count, token_count) cache[line] = (remade_tokens, fixes, multipliers) @@ -259,7 +277,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - def process_text_old(self, text): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id @@ -289,7 +306,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, + i) mult_change = self.token_mults.get(token) if opts.enable_emphasis else None if mult_change is not None: @@ -312,11 +330,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ovf = remade_tokens[maxlen - 2:] overflowing_words = [vocab.get(int(x), "") for x in ovf] overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + hijack_comments.append( + f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") token_count = len(remade_tokens) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) - remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] + remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] cache[tuple_tokens] = (remade_tokens, fixes, multipliers) multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) @@ -326,23 +345,26 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): hijack_fixes.append(fixes) batch_multipliers.append(multipliers) return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - + def forward(self, text): use_old = opts.use_old_emphasis_implementation if use_old: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old( + text) else: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text( + text) self.hijack.comments += hijack_comments if len(used_custom_terms) > 0: - self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) - + self.hijack.comments.append( + "Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + if use_old: self.hijack.fixes = hijack_fixes return self.process_tokens(remade_batch_tokens, batch_multipliers) - + z = None i = 0 while max(map(len, remade_batch_tokens)) != 0: @@ -356,7 +378,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if fix[0] == i: fixes.append(fix[1]) self.hijack.fixes.append(fixes) - + tokens = [] multipliers = [] for j in range(len(remade_batch_tokens)): @@ -378,19 +400,30 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens] tokens = torch.asarray(remade_batch_tokens).to(device) + + model = copy.deepcopy(self.clipModel).to(device) + model.requires_grad_(True) + if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: + text_embs_2 = model.get_text_features( + **self.tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device)) + if self.aesthetic_text_negative: + text_embs_2 = self.image_embs - text_embs_2 + text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True) + img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle) + else: + img_embs = self.image_embs + with torch.enable_grad(): - model = copy.deepcopy(self.clipModel).to(device) - model.requires_grad_(True) # We optimize the model to maximize the similarity optimizer = optim.Adam( model.text_model.parameters(), lr=self.aesthetic_lr ) - for i in range(self.aesthetic_steps): + for i in trange(self.aesthetic_steps, desc="Aesthetic optimization"): text_embs = model.get_text_features(input_ids=tokens) text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True) - sim = text_embs @ self.image_embs.T + sim = text_embs @ img_embs.T loss = -sim optimizer.zero_grad() loss.mean().backward() @@ -405,6 +438,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): model.cpu() del model + zn = torch.concat([zn for i in range(z.shape[1] // 77)], 1) if self.slerp: z = slerp(z, zn, self.aesthetic_weight) else: @@ -413,15 +447,16 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens = rem_tokens batch_multipliers = rem_multipliers i += 1 - + return z - - + def process_tokens(self, remade_batch_tokens, batch_multipliers): if not opts.use_old_emphasis_implementation: - remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens] + remade_batch_tokens = [ + [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in + remade_batch_tokens] batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] - + tokens = torch.asarray(remade_batch_tokens).to(device) outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) @@ -461,8 +496,8 @@ class EmbeddingsWithFixes(torch.nn.Module): for fixes, tensor in zip(batch_fixes, inputs_embeds): for offset, embedding in fixes: emb = embedding.vec - emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) - tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]]) + emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) + tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) vecs.append(tensor) -- cgit v1.2.3 From 529afbf4d70165a0dfd19eb9c2ec22416b794a1d Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 15 Oct 2022 19:19:54 +0300 Subject: Update sd_hijack.py --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index c81722a0..984b35c4 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -24,7 +24,7 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)): + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward -- cgit v1.2.3 From 9325c85f780c569d1823e422eaf51b2e497e0d3e Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sun, 16 Oct 2022 00:23:47 +0200 Subject: fixed dropbox update --- modules/sd_hijack.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 192883b2..491312b4 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -9,7 +9,7 @@ from torch.nn.functional import silu import modules.textual_inversion.textual_inversion from modules import prompt_parser, devices, sd_hijack_optimizations, shared -from modules.shared import opts, device, cmd_opts, aesthetic_embeddings +from modules.shared import opts, device, cmd_opts from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention @@ -182,7 +182,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): image_embs_name = None if image_embs_name is not None and self.image_embs_name != image_embs_name: self.image_embs_name = image_embs_name - self.image_embs = torch.load(aesthetic_embeddings[self.image_embs_name], map_location=device) + self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device) self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True) self.image_embs.requires_grad_(False) -- cgit v1.2.3 From 523140d7805c644700009b8a2483ff4eb4a22304 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sun, 16 Oct 2022 10:23:30 +0200 Subject: ui fix --- modules/sd_hijack.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 01fcb78f..2de2eed5 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -392,8 +392,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z1 = self.process_tokens(tokens, multipliers) z = z1 if z is None else torch.cat((z, z1), axis=-2) - if len(text[ - 0]) != 0 and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None: + if self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None: if not opts.use_old_emphasis_implementation: remade_batch_tokens = [ [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in -- cgit v1.2.3 From e4f8b5f00dd33b7547cc6b76fbed26bb83b37a64 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sun, 16 Oct 2022 10:28:21 +0200 Subject: ui fix --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 2de2eed5..5d0590af 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -178,7 +178,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.load_image_embs(image_embs_name) def load_image_embs(self, image_embs_name): - if image_embs_name is None or len(image_embs_name) == 0: + if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None": image_embs_name = None if image_embs_name is not None and self.image_embs_name != image_embs_name: self.image_embs_name = image_embs_name -- cgit v1.2.3 From 9324cdaa3199d65c182858785dd1eca42b192b8e Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sun, 16 Oct 2022 17:53:56 +0200 Subject: ui fix, re organization of the code --- modules/sd_hijack.py | 102 +++------------------------------------------------ 1 file changed, 5 insertions(+), 97 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 5d0590af..227e7670 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -29,8 +29,8 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu - - if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and ( + 6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward @@ -118,33 +118,14 @@ class StableDiffusionModelHijack: return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) -def slerp(low, high, val): - low_norm = low / torch.norm(low, dim=1, keepdim=True) - high_norm = high / torch.norm(high, dim=1, keepdim=True) - omega = torch.acos((low_norm * high_norm).sum(1)) - so = torch.sin(omega) - res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high - return res - - class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): super().__init__() self.wrapped = wrapped - self.clipModel = CLIPModel.from_pretrained( - self.wrapped.transformer.name_or_path - ) - del self.clipModel.vision_model - self.tokenizer = CLIPTokenizer.from_pretrained(self.wrapped.transformer.name_or_path) - self.hijack: StableDiffusionModelHijack = hijack - self.tokenizer = wrapped.tokenizer - # self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval() - self.image_embs_name = None - self.image_embs = None - self.load_image_embs(None) self.token_mults = {} - + self.hijack: StableDiffusionModelHijack = hijack + self.tokenizer = wrapped.tokenizer self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if @@ -164,28 +145,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult - def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, - aesthetic_slerp=True, aesthetic_imgs_text="", - aesthetic_slerp_angle=0.15, - aesthetic_text_negative=False): - self.aesthetic_imgs_text = aesthetic_imgs_text - self.aesthetic_slerp_angle = aesthetic_slerp_angle - self.aesthetic_text_negative = aesthetic_text_negative - self.slerp = aesthetic_slerp - self.aesthetic_lr = aesthetic_lr - self.aesthetic_weight = aesthetic_weight - self.aesthetic_steps = aesthetic_steps - self.load_image_embs(image_embs_name) - - def load_image_embs(self, image_embs_name): - if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None": - image_embs_name = None - if image_embs_name is not None and self.image_embs_name != image_embs_name: - self.image_embs_name = image_embs_name - self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device) - self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True) - self.image_embs.requires_grad_(False) - def tokenize_line(self, line, used_custom_terms, hijack_comments): id_end = self.wrapped.tokenizer.eos_token_id @@ -391,58 +350,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z1 = self.process_tokens(tokens, multipliers) z = z1 if z is None else torch.cat((z, z1), axis=-2) - - if self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None: - if not opts.use_old_emphasis_implementation: - remade_batch_tokens = [ - [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in - remade_batch_tokens] - - tokens = torch.asarray(remade_batch_tokens).to(device) - - model = copy.deepcopy(self.clipModel).to(device) - model.requires_grad_(True) - if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: - text_embs_2 = model.get_text_features( - **self.tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device)) - if self.aesthetic_text_negative: - text_embs_2 = self.image_embs - text_embs_2 - text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True) - img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle) - else: - img_embs = self.image_embs - - with torch.enable_grad(): - - # We optimize the model to maximize the similarity - optimizer = optim.Adam( - model.text_model.parameters(), lr=self.aesthetic_lr - ) - - for i in trange(self.aesthetic_steps, desc="Aesthetic optimization"): - text_embs = model.get_text_features(input_ids=tokens) - text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True) - sim = text_embs @ img_embs.T - loss = -sim - optimizer.zero_grad() - loss.mean().backward() - optimizer.step() - - zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) - if opts.CLIP_stop_at_last_layers > 1: - zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers] - zn = model.text_model.final_layer_norm(zn) - else: - zn = zn.last_hidden_state - model.cpu() - del model - - zn = torch.concat([zn for i in range(z.shape[1] // 77)], 1) - if self.slerp: - z = slerp(z, zn, self.aesthetic_weight) - else: - z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight - + z = shared.aesthetic_clip(z, remade_batch_tokens) remade_batch_tokens = rem_tokens batch_multipliers = rem_multipliers i += 1 -- cgit v1.2.3 From 786ed499226177d71e937e0342bcb9d3b1ff260f Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Mon, 17 Oct 2022 19:48:39 +0300 Subject: use legacy attnblock --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 984b35c4..2407a461 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -27,7 +27,7 @@ def apply_optimizations(): if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 -- cgit v1.2.3 From 73b5dbf72a93b64445551c74a4c0dc924986081d Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Mon, 17 Oct 2022 22:19:18 +0300 Subject: Update sd_hijack.py --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 2407a461..984b35c4 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -27,7 +27,7 @@ def apply_optimizations(): if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward - ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 -- cgit v1.2.3 From 9286fe53de2eef91f13cc3ad5938ddf67ecc8413 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 16:38:06 +0300 Subject: make aestetic embedding ciompatible with prompts longer than 75 tokens --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 36198a3c..1f8587d1 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -332,8 +332,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): multipliers.append([1.0] * 75) z1 = self.process_tokens(tokens, multipliers) + z1 = shared.aesthetic_clip(z1, remade_batch_tokens) z = z1 if z is None else torch.cat((z, z1), axis=-2) - z = shared.aesthetic_clip(z, remade_batch_tokens) remade_batch_tokens = rem_tokens batch_multipliers = rem_multipliers -- cgit v1.2.3 From 2b91251637078e04472c91a06a8d9c4db9c1dcf0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 12:23:45 +0300 Subject: removed aesthetic gradients as built-in added support for extensions --- modules/sd_hijack.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 1f8587d1..0f10828e 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -332,7 +332,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): multipliers.append([1.0] * 75) z1 = self.process_tokens(tokens, multipliers) - z1 = shared.aesthetic_clip(z1, remade_batch_tokens) z = z1 if z is None else torch.cat((z, z1), axis=-2) remade_batch_tokens = rem_tokens -- cgit v1.2.3 From af758e97fa2c4c853042f121af4e974be01e6696 Mon Sep 17 00:00:00 2001 From: Jairo Correa Date: Tue, 1 Nov 2022 04:01:49 -0300 Subject: Unload sd_model before loading the other --- modules/sd_hijack.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 0f10828e..bc49d235 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -94,6 +94,10 @@ class StableDiffusionModelHijack: if type(model_embeddings.token_embedding) == EmbeddingsWithFixes: model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped + self.layers = None + self.circular_enabled = False + self.clip = None + def apply_circular(self, enable): if self.circular_enabled == enable: return -- cgit v1.2.3 From 7ba3923d5b494b7756d0b12f33acb3716d830b9a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 11 Nov 2022 18:20:18 +0300 Subject: move DDIM/PLMS fix for OSX out of the file with inpainting code. --- modules/sd_hijack.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index bc49d235..75b2d22d 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -14,6 +14,8 @@ from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention import ldm.modules.diffusionmodules.model +import ldm.models.diffusion.ddim +import ldm.models.diffusion.plms attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity @@ -406,3 +408,24 @@ def add_circular_option_to_conv_2d(): model_hijack = StableDiffusionModelHijack() + + +def register_buffer(self, name, attr): + """ + Fix register buffer bug for Mac OS. + """ + + if type(attr) == torch.Tensor: + if attr.device != devices.device: + + # would this not break cuda when torch adds has_mps() to main version? + if getattr(torch, 'has_mps', False): + attr = attr.to(device="mps", dtype=torch.float32) + else: + attr = attr.to(devices.device) + + setattr(self, name, attr) + + +ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer +ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer -- cgit v1.2.3 From c62d17aee36b5f4ca24f9cfa7bf6d7aca0c923f8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 12 Nov 2022 10:00:22 +0300 Subject: use the new devices.has_mps() function in register_buffer for DDIM/PLMS fix for OSX --- modules/sd_hijack.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 75b2d22d..97979d05 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -418,8 +418,7 @@ def register_buffer(self, name, attr): if type(attr) == torch.Tensor: if attr.device != devices.device: - # would this not break cuda when torch adds has_mps() to main version? - if getattr(torch, 'has_mps', False): + if devices.has_mps(): attr = attr.to(device="mps", dtype=torch.float32) else: attr = attr.to(devices.device) -- cgit v1.2.3 From 17e44328204a09653bb89eea18b7b489cc118703 Mon Sep 17 00:00:00 2001 From: killfrenzy96 Date: Fri, 18 Nov 2022 21:22:55 +1100 Subject: cleanly undo circular hijack #4818 --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 97979d05..eaedac13 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -96,8 +96,8 @@ class StableDiffusionModelHijack: if type(model_embeddings.token_embedding) == EmbeddingsWithFixes: model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped + self.apply_circular(False) self.layers = None - self.circular_enabled = False self.clip = None def apply_circular(self, enable): -- cgit v1.2.3 From bd68e35de3b7cf7547ed97d8bdf60147402133cc Mon Sep 17 00:00:00 2001 From: flamelaw Date: Sun, 20 Nov 2022 12:35:26 +0900 Subject: Gradient accumulation, autocast fix, new latent sampling method, etc --- modules/sd_hijack.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index eaedac13..29c8b561 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -8,7 +8,7 @@ from torch import einsum from torch.nn.functional import silu import modules.textual_inversion.textual_inversion -from modules import prompt_parser, devices, sd_hijack_optimizations, shared +from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint from modules.shared import opts, device, cmd_opts from modules.sd_hijack_optimizations import invokeAI_mps_available @@ -59,6 +59,10 @@ def undo_optimizations(): def get_target_prompt_token_count(token_count): return math.ceil(max(token_count, 1) / 75) * 75 +def fix_checkpoint(): + ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward + ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward + ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward class StableDiffusionModelHijack: fixes = None @@ -78,6 +82,7 @@ class StableDiffusionModelHijack: self.clip = m.cond_stage_model apply_optimizations() + fix_checkpoint() def flatten(el): flattened = [flatten(children) for children in el.children()] @@ -303,7 +308,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) else: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - + self.hijack.comments += hijack_comments if len(used_custom_terms) > 0: -- cgit v1.2.3 From adb6cb7619989cbc7a271cc6c2ae27bb936c43d9 Mon Sep 17 00:00:00 2001 From: Billy Cao Date: Wed, 23 Nov 2022 18:11:24 +0800 Subject: Patch UNet Forward to support resolutions that are not multiples of 64 Also modifed the UI to no longer step in 64 --- modules/sd_hijack.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index eaedac13..6141f705 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -16,6 +16,7 @@ import ldm.modules.attention import ldm.modules.diffusionmodules.model import ldm.models.diffusion.ddim import ldm.models.diffusion.plms +import ldm.modules.diffusionmodules.openaimodel attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity @@ -26,6 +27,7 @@ def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu + ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_hijack_optimizations.patched_unet_forward if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") -- cgit v1.2.3 From ce6911158b5b2f9cf79b405a1f368f875492044d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 26 Nov 2022 16:10:46 +0300 Subject: Add support Stable Diffusion 2.0 --- modules/sd_hijack.py | 297 +++++---------------------------------------------- 1 file changed, 28 insertions(+), 269 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index eaedac13..d5243fd3 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -9,18 +9,29 @@ from torch.nn.functional import silu import modules.textual_inversion.textual_inversion from modules import prompt_parser, devices, sd_hijack_optimizations, shared -from modules.shared import opts, device, cmd_opts +from modules.shared import cmd_opts +from modules import sd_hijack_clip, sd_hijack_open_clip + from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention import ldm.modules.diffusionmodules.model import ldm.models.diffusion.ddim import ldm.models.diffusion.plms +import ldm.modules.encoders.modules attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward +# new memory efficient cross attention blocks do not support hypernets and we already +# have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention +ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention +ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention + +# silence new console spam from SD2 +ldm.modules.attention.print = lambda *args: None +ldm.modules.diffusionmodules.model.print = lambda *args: None def apply_optimizations(): undo_optimizations() @@ -49,16 +60,11 @@ def apply_optimizations(): def undo_optimizations(): - from modules.hypernetworks import hypernetwork - - ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward + ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward # this stops hypernets from working ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward -def get_target_prompt_token_count(token_count): - return math.ceil(max(token_count, 1) / 75) * 75 - class StableDiffusionModelHijack: fixes = None @@ -70,10 +76,13 @@ class StableDiffusionModelHijack: embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) def hijack(self, m): - model_embeddings = m.cond_stage_model.transformer.text_model.embeddings - - model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) - m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) + if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder: + model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) + m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) + elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder: + m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self) + m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) self.clip = m.cond_stage_model @@ -89,12 +98,15 @@ class StableDiffusionModelHijack: self.layers = flatten(m) def undo_hijack(self, m): - if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords: + if type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords: m.cond_stage_model = m.cond_stage_model.wrapped - model_embeddings = m.cond_stage_model.transformer.text_model.embeddings - if type(model_embeddings.token_embedding) == EmbeddingsWithFixes: - model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped + model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + if type(model_embeddings.token_embedding) == EmbeddingsWithFixes: + model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped + elif type(m.cond_stage_model) == sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords: + m.cond_stage_model.wrapped.model.token_embedding = m.cond_stage_model.wrapped.model.token_embedding.wrapped + m.cond_stage_model = m.cond_stage_model.wrapped self.apply_circular(False) self.layers = None @@ -114,261 +126,8 @@ class StableDiffusionModelHijack: def tokenize(self, text): _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) - return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) - - -class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): - def __init__(self, wrapped, hijack): - super().__init__() - self.wrapped = wrapped - self.hijack: StableDiffusionModelHijack = hijack - self.tokenizer = wrapped.tokenizer - self.token_mults = {} - - self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] - - tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] - for text, ident in tokens_with_parens: - mult = 1.0 - for c in text: - if c == '[': - mult /= 1.1 - if c == ']': - mult *= 1.1 - if c == '(': - mult *= 1.1 - if c == ')': - mult /= 1.1 - - if mult != 1.0: - self.token_mults[ident] = mult - - def tokenize_line(self, line, used_custom_terms, hijack_comments): - id_end = self.wrapped.tokenizer.eos_token_id - - if opts.enable_emphasis: - parsed = prompt_parser.parse_prompt_attention(line) - else: - parsed = [[line, 1.0]] - - tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"] - - fixes = [] - remade_tokens = [] - multipliers = [] - last_comma = -1 - - for tokens, (text, weight) in zip(tokenized, parsed): - i = 0 - while i < len(tokens): - token = tokens[i] - - embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) - - if token == self.comma_token: - last_comma = len(remade_tokens) - elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack: - last_comma += 1 - reloc_tokens = remade_tokens[last_comma:] - reloc_mults = multipliers[last_comma:] - - remade_tokens = remade_tokens[:last_comma] - length = len(remade_tokens) - - rem = int(math.ceil(length / 75)) * 75 - length - remade_tokens += [id_end] * rem + reloc_tokens - multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults - - if embedding is None: - remade_tokens.append(token) - multipliers.append(weight) - i += 1 - else: - emb_len = int(embedding.vec.shape[0]) - iteration = len(remade_tokens) // 75 - if (len(remade_tokens) + emb_len) // 75 != iteration: - rem = (75 * (iteration + 1) - len(remade_tokens)) - remade_tokens += [id_end] * rem - multipliers += [1.0] * rem - iteration += 1 - fixes.append((iteration, (len(remade_tokens) % 75, embedding))) - remade_tokens += [0] * emb_len - multipliers += [weight] * emb_len - used_custom_terms.append((embedding.name, embedding.checksum())) - i += embedding_length_in_tokens - - token_count = len(remade_tokens) - prompt_target_length = get_target_prompt_token_count(token_count) - tokens_to_add = prompt_target_length - len(remade_tokens) - - remade_tokens = remade_tokens + [id_end] * tokens_to_add - multipliers = multipliers + [1.0] * tokens_to_add - - return remade_tokens, fixes, multipliers, token_count - - def process_text(self, texts): - used_custom_terms = [] - remade_batch_tokens = [] - hijack_comments = [] - hijack_fixes = [] - token_count = 0 - - cache = {} - batch_multipliers = [] - for line in texts: - if line in cache: - remade_tokens, fixes, multipliers = cache[line] - else: - remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) - token_count = max(current_token_count, token_count) - - cache[line] = (remade_tokens, fixes, multipliers) - - remade_batch_tokens.append(remade_tokens) - hijack_fixes.append(fixes) - batch_multipliers.append(multipliers) - - return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - - def process_text_old(self, text): - id_start = self.wrapped.tokenizer.bos_token_id - id_end = self.wrapped.tokenizer.eos_token_id - maxlen = self.wrapped.max_length # you get to stay at 77 - used_custom_terms = [] - remade_batch_tokens = [] - overflowing_words = [] - hijack_comments = [] - hijack_fixes = [] - token_count = 0 - - cache = {} - batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"] - batch_multipliers = [] - for tokens in batch_tokens: - tuple_tokens = tuple(tokens) - - if tuple_tokens in cache: - remade_tokens, fixes, multipliers = cache[tuple_tokens] - else: - fixes = [] - remade_tokens = [] - multipliers = [] - mult = 1.0 - - i = 0 - while i < len(tokens): - token = tokens[i] - - embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) - - mult_change = self.token_mults.get(token) if opts.enable_emphasis else None - if mult_change is not None: - mult *= mult_change - i += 1 - elif embedding is None: - remade_tokens.append(token) - multipliers.append(mult) - i += 1 - else: - emb_len = int(embedding.vec.shape[0]) - fixes.append((len(remade_tokens), embedding)) - remade_tokens += [0] * emb_len - multipliers += [mult] * emb_len - used_custom_terms.append((embedding.name, embedding.checksum())) - i += embedding_length_in_tokens - - if len(remade_tokens) > maxlen - 2: - vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} - ovf = remade_tokens[maxlen - 2:] - overflowing_words = [vocab.get(int(x), "") for x in ovf] - overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") - - token_count = len(remade_tokens) - remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) - remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] - cache[tuple_tokens] = (remade_tokens, fixes, multipliers) - - multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) - multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] - - remade_batch_tokens.append(remade_tokens) - hijack_fixes.append(fixes) - batch_multipliers.append(multipliers) - return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - - def forward(self, text): - use_old = opts.use_old_emphasis_implementation - if use_old: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) - else: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - - self.hijack.comments += hijack_comments - - if len(used_custom_terms) > 0: - self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) - - if use_old: - self.hijack.fixes = hijack_fixes - return self.process_tokens(remade_batch_tokens, batch_multipliers) - - z = None - i = 0 - while max(map(len, remade_batch_tokens)) != 0: - rem_tokens = [x[75:] for x in remade_batch_tokens] - rem_multipliers = [x[75:] for x in batch_multipliers] - - self.hijack.fixes = [] - for unfiltered in hijack_fixes: - fixes = [] - for fix in unfiltered: - if fix[0] == i: - fixes.append(fix[1]) - self.hijack.fixes.append(fixes) - - tokens = [] - multipliers = [] - for j in range(len(remade_batch_tokens)): - if len(remade_batch_tokens[j]) > 0: - tokens.append(remade_batch_tokens[j][:75]) - multipliers.append(batch_multipliers[j][:75]) - else: - tokens.append([self.wrapped.tokenizer.eos_token_id] * 75) - multipliers.append([1.0] * 75) - - z1 = self.process_tokens(tokens, multipliers) - z = z1 if z is None else torch.cat((z, z1), axis=-2) - - remade_batch_tokens = rem_tokens - batch_multipliers = rem_multipliers - i += 1 - - return z - - def process_tokens(self, remade_batch_tokens, batch_multipliers): - if not opts.use_old_emphasis_implementation: - remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens] - batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] - - tokens = torch.asarray(remade_batch_tokens).to(device) - outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) - - if opts.CLIP_stop_at_last_layers > 1: - z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] - z = self.wrapped.transformer.text_model.final_layer_norm(z) - else: - z = outputs.last_hidden_state - - # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise - batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers] - batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device) - original_mean = z.mean() - z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) - new_mean = z.mean() - z *= original_mean / new_mean + return remade_batch_tokens[0], token_count, sd_hijack_clip.get_target_prompt_token_count(token_count) - return z class EmbeddingsWithFixes(torch.nn.Module): -- cgit v1.2.3 From 64c7b7975cedeb2aaa1a9c8eb4a479fc575843f8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 26 Nov 2022 16:45:57 +0300 Subject: restore hypernetworks to seemingly working state --- modules/sd_hijack.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index d5243fd3..64655eb1 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -9,6 +9,7 @@ from torch.nn.functional import silu import modules.textual_inversion.textual_inversion from modules import prompt_parser, devices, sd_hijack_optimizations, shared +from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts from modules import sd_hijack_clip, sd_hijack_open_clip @@ -60,7 +61,7 @@ def apply_optimizations(): def undo_optimizations(): - ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward # this stops hypernets from working + ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward -- cgit v1.2.3 From 98ca437edfbf71dd956d67d37f2136b12d13be0d Mon Sep 17 00:00:00 2001 From: brkirch Date: Sat, 12 Nov 2022 02:17:55 -0500 Subject: Refactor and instead check if mps is being used, not availability --- modules/sd_hijack.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index b824b5bf..ce583950 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -182,11 +182,7 @@ def register_buffer(self, name, attr): if type(attr) == torch.Tensor: if attr.device != devices.device: - - if devices.has_mps(): - attr = attr.to(device="mps", dtype=torch.float32) - else: - attr = attr.to(devices.device) + attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None)) setattr(self, name, attr) -- cgit v1.2.3 From 75c4511e6b81ae8fb0dbd932043e8eb35cd09f72 Mon Sep 17 00:00:00 2001 From: zhaohu xing <920232796@qq.com> Date: Tue, 29 Nov 2022 10:28:41 +0800 Subject: add AltDiffusion to webui Signed-off-by: zhaohu xing <920232796@qq.com> --- modules/sd_hijack.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index eaedac13..26280fe4 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -70,14 +70,19 @@ class StableDiffusionModelHijack: embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) def hijack(self, m): - model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + + if shared.text_model_name == "XLMR-Large": + model_embeddings = m.cond_stage_model.roberta.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) + else : + model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embeddings, self) - model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) self.clip = m.cond_stage_model - apply_optimizations() + # apply_optimizations() def flatten(el): flattened = [flatten(children) for children in el.children()] @@ -125,8 +130,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.tokenizer = wrapped.tokenizer self.token_mults = {} - self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] - + try: + self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] + except: + self.comma_token = None + tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] for text, ident in tokens_with_parens: mult = 1.0 @@ -298,6 +306,9 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count def forward(self, text): + if shared.text_model_name == "XLMR-Large": + return self.wrapped.encode(text) + use_old = opts.use_old_emphasis_implementation if use_old: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) @@ -359,7 +370,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z = self.wrapped.transformer.text_model.final_layer_norm(z) else: z = outputs.last_hidden_state - + # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers] batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device) -- cgit v1.2.3 From 36c3613d16c523e43ec4dedbcbe9a3b93ad7d139 Mon Sep 17 00:00:00 2001 From: wywywywy Date: Tue, 29 Nov 2022 17:40:02 +0000 Subject: Add autoencoder to sd_hijack --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index b824b5bf..26f9b951 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -11,7 +11,7 @@ import modules.textual_inversion.textual_inversion from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint from modules.hypernetworks import hypernetwork from modules.shared import opts, device, cmd_opts -from modules import sd_hijack_clip, sd_hijack_open_clip +from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_autoencoder from modules.sd_hijack_optimizations import invokeAI_mps_available -- cgit v1.2.3 From 52cc83d36b7663a77b79fd2258d2ca871af73e55 Mon Sep 17 00:00:00 2001 From: zhaohu xing <920232796@qq.com> Date: Wed, 30 Nov 2022 14:56:12 +0800 Subject: fix bugs Signed-off-by: zhaohu xing <920232796@qq.com> --- modules/sd_hijack.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3ec3f98a..edb8b420 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -28,7 +28,7 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At # new memory efficient cross attention blocks do not support hypernets and we already # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention -ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention +# ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention # silence new console spam from SD2 ldm.modules.attention.print = lambda *args: None @@ -82,7 +82,12 @@ class StableDiffusionModelHijack: def hijack(self, m): - if type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder: + if shared.text_model_name == "XLMR-Large": + model_embeddings = m.cond_stage_model.roberta.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) + m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) + + elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder: model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) @@ -91,11 +96,7 @@ class StableDiffusionModelHijack: m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self) m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) apply_optimizations() - elif shared.text_model_name == "XLMR-Large": - model_embeddings = m.cond_stage_model.roberta.embeddings - model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) - m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) - + self.clip = m.cond_stage_model fix_checkpoint() -- cgit v1.2.3 From da698ca92ed79b9104a62f34291d9b842c433a1b Mon Sep 17 00:00:00 2001 From: SmirkingFace <116507648+smirkingface@users.noreply.github.com> Date: Fri, 2 Dec 2022 13:47:02 +0100 Subject: Fixed AttributeError where openaimodel is not found --- modules/sd_hijack.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index b824b5bf..eef6efd2 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -17,6 +17,7 @@ from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention import ldm.modules.diffusionmodules.model +import ldm.modules.diffusionmodules.openaimodel import ldm.models.diffusion.ddim import ldm.models.diffusion.plms import ldm.modules.encoders.modules -- cgit v1.2.3 From 0d21624ceef52b843c731ddc7fdcd7b8d108a42e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 3 Dec 2022 18:16:19 +0300 Subject: move #5216 to the extension --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 303b1397..95a17093 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -11,7 +11,7 @@ import modules.textual_inversion.textual_inversion from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint from modules.hypernetworks import hypernetwork from modules.shared import opts, device, cmd_opts -from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_autoencoder +from modules import sd_hijack_clip, sd_hijack_open_clip from modules.sd_hijack_optimizations import invokeAI_mps_available -- cgit v1.2.3 From 4929503258d80abbc4b5f40da034298fe3803906 Mon Sep 17 00:00:00 2001 From: zhaohu xing <920232796@qq.com> Date: Tue, 6 Dec 2022 09:03:55 +0800 Subject: fix bugs Signed-off-by: zhaohu xing <920232796@qq.com> --- modules/sd_hijack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index edb8b420..cd65d356 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -28,7 +28,7 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At # new memory efficient cross attention blocks do not support hypernets and we already # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention -# ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention +ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention # silence new console spam from SD2 ldm.modules.attention.print = lambda *args: None -- cgit v1.2.3 From 5dcc22606d05ebe5ae89c990bd83a3eb068fcb78 Mon Sep 17 00:00:00 2001 From: zhaohu xing <920232796@qq.com> Date: Tue, 6 Dec 2022 16:04:50 +0800 Subject: add hash and fix undo hijack bug Signed-off-by: zhaohu xing <920232796@qq.com> --- modules/sd_hijack.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 9b5890e7..9fed1b6f 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -112,7 +112,11 @@ class StableDiffusionModelHijack: self.layers = flatten(m) def undo_hijack(self, m): - if type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords: + + if shared.text_model_name == "XLMR-Large": + m.cond_stage_model = m.cond_stage_model.wrapped + + elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords: m.cond_stage_model = m.cond_stage_model.wrapped model_embeddings = m.cond_stage_model.transformer.text_model.embeddings -- cgit v1.2.3 From 7dbfd8a7d8aefec7283b456c6f5b000ae4d3496d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 10 Dec 2022 09:14:30 +0300 Subject: do not replace entire unet for the resolution hack --- modules/sd_hijack.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 92874a79..47dbc1b7 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -11,7 +11,7 @@ import modules.textual_inversion.textual_inversion from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint from modules.hypernetworks import hypernetwork from modules.shared import opts, device, cmd_opts -from modules import sd_hijack_clip, sd_hijack_open_clip +from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet from modules.sd_hijack_optimizations import invokeAI_mps_available @@ -35,11 +35,12 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] ldm.modules.attention.print = lambda *args: None ldm.modules.diffusionmodules.model.print = lambda *args: None + def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_hijack_optimizations.patched_unet_forward + ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") -- cgit v1.2.3 From 505ec7e4d960e7bea579182509050fafb10bd00c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 10 Dec 2022 09:17:39 +0300 Subject: cleanup some unneeded imports for hijack files --- modules/sd_hijack.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 47dbc1b7..690a9ec2 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -1,16 +1,10 @@ -import math -import os -import sys -import traceback import torch -import numpy as np -from torch import einsum from torch.nn.functional import silu import modules.textual_inversion.textual_inversion -from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint +from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint from modules.hypernetworks import hypernetwork -from modules.shared import opts, device, cmd_opts +from modules.shared import cmd_opts from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet from modules.sd_hijack_optimizations import invokeAI_mps_available -- cgit v1.2.3 From f34c7341720fb2059992926c9f9ae6ff25f7385b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 31 Dec 2022 18:06:35 +0300 Subject: alt-diffusion integration --- modules/sd_hijack.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index bce23b03..edcbaf52 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -5,7 +5,7 @@ import modules.textual_inversion.textual_inversion from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts -from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet +from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr from modules.sd_hijack_optimizations import invokeAI_mps_available @@ -68,6 +68,7 @@ def fix_checkpoint(): ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward + class StableDiffusionModelHijack: fixes = None comments = [] @@ -79,21 +80,22 @@ class StableDiffusionModelHijack: def hijack(self, m): - if shared.text_model_name == "XLMR-Large": + if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) - m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) - + m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self) + elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder: model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) - apply_optimizations() + elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder: m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self) m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) - apply_optimizations() - + + apply_optimizations() + self.clip = m.cond_stage_model fix_checkpoint() @@ -109,7 +111,7 @@ class StableDiffusionModelHijack: def undo_hijack(self, m): - if shared.text_model_name == "XLMR-Large": + if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: m.cond_stage_model = m.cond_stage_model.wrapped elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords: -- cgit v1.2.3 From 21ee77db314ede7ccbb18787962347c09a4df0c7 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Wed, 4 Jan 2023 08:04:38 -0500 Subject: add cross-attention info --- modules/sd_hijack.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index edcbaf52..fa2cd4bb 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -35,26 +35,35 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th + + optimization_method = None if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward + optimization_method = 'xformers' elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 + optimization_method = 'V1' elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): if not invokeAI_mps_available and shared.device.type == 'mps': print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.") print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 + optimization_method = 'V1' else: print("Applying cross attention optimization (InvokeAI).") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI + optimization_method = 'InvokeAI' elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): print("Applying cross attention optimization (Doggettx).") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward + optimization_method = 'Doggettx' + + return optimization_method def undo_optimizations(): @@ -75,6 +84,7 @@ class StableDiffusionModelHijack: layers = None circular_enabled = False clip = None + optimization_method = None embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) @@ -94,7 +104,7 @@ class StableDiffusionModelHijack: m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self) m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self) - apply_optimizations() + self.optimization_method = apply_optimizations() self.clip = m.cond_stage_model -- cgit v1.2.3