From 8deae077004f0332ca607fc3a5d568b1a4705bec Mon Sep 17 00:00:00 2001 From: d8ahazard Date: Fri, 30 Sep 2022 15:28:37 -0500 Subject: Add ScuNET DeNoiser/Upscaler Q&D Implementation of ScuNET, thanks to our handy model loader. :P https://github.com/cszn/SCUNet --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 8428c7a3..a48b995a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -40,6 +40,7 @@ parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory wi parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN')) parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN')) parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN')) +parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(model_path, 'ScuNET')) parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR')) parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") -- cgit v1.2.3 From 820f1dc96b1979d7e92170c161db281ee8bd988b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 15:03:39 +0300 Subject: initial support for training textual inversion --- .gitignore | 1 + javascript/progressbar.js | 1 + javascript/textualInversion.js | 8 + modules/devices.py | 3 +- modules/processing.py | 13 +- modules/sd_hijack.py | 324 ++++------------------ modules/sd_hijack_optimizations.py | 164 +++++++++++ modules/sd_models.py | 4 +- modules/shared.py | 3 +- modules/textual_inversion/dataset.py | 76 +++++ modules/textual_inversion/textual_inversion.py | 258 +++++++++++++++++ modules/textual_inversion/ui.py | 32 +++ modules/ui.py | 139 ++++++++-- style.css | 10 +- textual_inversion_templates/style.txt | 19 ++ textual_inversion_templates/style_filewords.txt | 19 ++ textual_inversion_templates/subject.txt | 27 ++ textual_inversion_templates/subject_filewords.txt | 27 ++ webui.py | 15 +- 19 files changed, 828 insertions(+), 315 deletions(-) create mode 100644 javascript/textualInversion.js create mode 100644 modules/sd_hijack_optimizations.py create mode 100644 modules/textual_inversion/dataset.py create mode 100644 modules/textual_inversion/textual_inversion.py create mode 100644 modules/textual_inversion/ui.py create mode 100644 textual_inversion_templates/style.txt create mode 100644 textual_inversion_templates/style_filewords.txt create mode 100644 textual_inversion_templates/subject.txt create mode 100644 textual_inversion_templates/subject_filewords.txt (limited to 'modules/shared.py') diff --git a/.gitignore b/.gitignore index 3532dab3..7afc9395 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,4 @@ __pycache__ /.idea notification.mp3 /SwinIR +/textual_inversion diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 21f25b38..1e297abb 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -30,6 +30,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte onUiUpdate(function(){ check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') + check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery') }) function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){ diff --git a/javascript/textualInversion.js b/javascript/textualInversion.js new file mode 100644 index 00000000..8061be08 --- /dev/null +++ b/javascript/textualInversion.js @@ -0,0 +1,8 @@ + + +function start_training_textual_inversion(){ + requestProgress('ti') + gradioApp().querySelector('#ti_error').innerHTML='' + + return args_to_array(arguments) +} diff --git a/modules/devices.py b/modules/devices.py index 07bb2339..ff82f2f6 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -32,10 +32,9 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") - device = get_optimal_device() device_codeformer = cpu if has_mps else device - +dtype = torch.float16 def randn(seed, shape): # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. diff --git a/modules/processing.py b/modules/processing.py index 7eeb5191..8223423a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -56,7 +56,7 @@ class StableDiffusionProcessing: self.prompt: str = prompt self.prompt_for_display: str = None self.negative_prompt: str = (negative_prompt or "") - self.styles: str = styles + self.styles: list = styles or [] self.seed: int = seed self.subseed: int = subseed self.subseed_strength: float = subseed_strength @@ -271,7 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), - "Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta), + "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), } generation_params.update(p.extra_generation_params) @@ -295,8 +295,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed: fix_seed(p) - os.makedirs(p.outpath_samples, exist_ok=True) - os.makedirs(p.outpath_grids, exist_ok=True) + if p.outpath_samples is not None: + os.makedirs(p.outpath_samples, exist_ok=True) + + if p.outpath_grids is not None: + os.makedirs(p.outpath_grids, exist_ok=True) modules.sd_hijack.model_hijack.apply_circular(p.tiling) @@ -323,7 +326,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch) if os.path.exists(cmd_opts.embeddings_dir): - model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model) + model_hijack.embedding_db.load_textual_inversion_embeddings() infotexts = [] output_images = [] diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index fa7eaeb8..fd57e5c5 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -6,244 +6,41 @@ import torch import numpy as np from torch import einsum -from modules import prompt_parser +import modules.textual_inversion.textual_inversion +from modules import prompt_parser, devices, sd_hijack_optimizations, shared from modules.shared import opts, device, cmd_opts -from ldm.util import default -from einops import rearrange import ldm.modules.attention import ldm.modules.diffusionmodules.model +attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward +diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity +diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward -# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion -def split_cross_attention_forward_v1(self, x, context=None, mask=None): - h = self.heads - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - del context, x +def apply_optimizations(): + if cmd_opts.opt_split_attention_v1: + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 + elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward + ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack + ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - for i in range(0, q.shape[0], 2): - end = i + 2 - s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) - s1 *= self.scale +def undo_optimizations(): + ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward + ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity + ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward - s2 = s1.softmax(dim=-1) - del s1 - - r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) - del s2 - - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - - return self.to_out(r2) - - -# taken from https://github.com/Doggettx/stable-diffusion -def split_cross_attention_forward(self, x, context=None, mask=None): - h = self.heads - - q_in = self.to_q(x) - context = default(context, x) - k_in = self.to_k(context) * self.scale - v_in = self.to_v(context) - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - if mem_required > mem_free_total: - steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) - # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") - - if steps > 64: - max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 - raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' - f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) - - s2 = s1.softmax(dim=-1, dtype=q.dtype) - del s1 - - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 - - del q, k, v - - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - - return self.to_out(r2) - -def nonlinearity_hijack(x): - # swish - t = torch.sigmoid(x) - x *= t - del t - - return x - -def cross_attention_attnblock_forward(self, x): - h_ = x - h_ = self.norm(h_) - q1 = self.q(h_) - k1 = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q1.shape - - q2 = q1.reshape(b, c, h*w) - del q1 - - q = q2.permute(0, 2, 1) # b,hw,c - del q2 - - k = k1.reshape(b, c, h*w) # b,c,hw - del k1 - - h_ = torch.zeros_like(k, device=q.device) - - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch - - tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() - mem_required = tensor_size * 2.5 - steps = 1 - - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - - w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w2 = w1 * (int(c)**(-0.5)) - del w1 - w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) - del w2 - - # attend to values - v1 = v.reshape(b, c, h*w) - w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - del w3 - - h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - del v1, w4 - - h2 = h_.reshape(b, c, h, w) - del h_ - - h3 = self.proj_out(h2) - del h2 - - h3 += x - - return h3 class StableDiffusionModelHijack: - ids_lookup = {} - word_embeddings = {} - word_embeddings_checksums = {} fixes = None comments = [] - dir_mtime = None layers = None circular_enabled = False clip = None - def load_textual_inversion_embeddings(self, dirname, model): - mt = os.path.getmtime(dirname) - if self.dir_mtime is not None and mt <= self.dir_mtime: - return - - self.dir_mtime = mt - self.ids_lookup.clear() - self.word_embeddings.clear() - - tokenizer = model.cond_stage_model.tokenizer - - def const_hash(a): - r = 0 - for v in a: - r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF - return r - - def process_file(path, filename): - name = os.path.splitext(filename)[0] - - data = torch.load(path, map_location="cpu") - - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - if hasattr(param_dict, '_parameters'): - param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - # diffuser concepts - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - - self.word_embeddings[name] = emb.detach().to(device) - self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}' - - ids = tokenizer([name], add_special_tokens=False)['input_ids'][0] - - first_id = ids[0] - if first_id not in self.ids_lookup: - self.ids_lookup[first_id] = [] - self.ids_lookup[first_id].append((ids, name)) - - for fn in os.listdir(dirname): - try: - fullfn = os.path.join(dirname, fn) - - if os.stat(fullfn).st_size == 0: - continue - - process_file(fullfn, fn) - except Exception: - print(f"Error loading emedding {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - continue - - print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") + embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) def hijack(self, m): model_embeddings = m.cond_stage_model.transformer.text_model.embeddings @@ -253,12 +50,7 @@ class StableDiffusionModelHijack: self.clip = m.cond_stage_model - if cmd_opts.opt_split_attention_v1: - ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): - ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward - ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack - ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward + apply_optimizations() def flatten(el): flattened = [flatten(children) for children in el.children()] @@ -296,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): super().__init__() self.wrapped = wrapped - self.hijack = hijack + self.hijack: StableDiffusionModelHijack = hijack self.tokenizer = wrapped.tokenizer self.max_length = wrapped.max_length self.token_mults = {} @@ -317,7 +109,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult - def tokenize_line(self, line, used_custom_terms, hijack_comments): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id @@ -339,28 +130,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - possible_matches = self.hijack.ids_lookup.get(token, None) + embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) - if possible_matches is None: + if embedding is None: remade_tokens.append(token) multipliers.append(weight) + i += 1 else: - found = False - for ids, word in possible_matches: - if tokens[i:i + len(ids)] == ids: - emb_len = int(self.hijack.word_embeddings[word].shape[0]) - fixes.append((len(remade_tokens), word)) - remade_tokens += [0] * emb_len - multipliers += [weight] * emb_len - i += len(ids) - 1 - found = True - used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word])) - break - - if not found: - remade_tokens.append(token) - multipliers.append(weight) - i += 1 + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [weight] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += emb_len if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} @@ -431,32 +213,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - possible_matches = self.hijack.ids_lookup.get(token, None) + embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i) mult_change = self.token_mults.get(token) if opts.enable_emphasis else None if mult_change is not None: mult *= mult_change - elif possible_matches is None: + i += 1 + elif embedding is None: remade_tokens.append(token) multipliers.append(mult) + i += 1 else: - found = False - for ids, word in possible_matches: - if tokens[i:i+len(ids)] == ids: - emb_len = int(self.hijack.word_embeddings[word].shape[0]) - fixes.append((len(remade_tokens), word)) - remade_tokens += [0] * emb_len - multipliers += [mult] * emb_len - i += len(ids) - 1 - found = True - used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word])) - break - - if not found: - remade_tokens.append(token) - multipliers.append(mult) - - i += 1 + emb_len = int(embedding.vec.shape[0]) + fixes.append((len(remade_tokens), embedding)) + remade_tokens += [0] * emb_len + multipliers += [mult] * emb_len + used_custom_terms.append((embedding.name, embedding.checksum())) + i += emb_len if len(remade_tokens) > maxlen - 2: vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} @@ -464,6 +237,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): overflowing_words = [vocab.get(int(x), "") for x in ovf] overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + token_count = len(remade_tokens) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] @@ -484,7 +258,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): else: batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) - self.hijack.fixes = hijack_fixes self.hijack.comments = hijack_comments @@ -517,14 +290,19 @@ class EmbeddingsWithFixes(torch.nn.Module): inputs_embeds = self.wrapped(input_ids) - if batch_fixes is not None: - for fixes, tensor in zip(batch_fixes, inputs_embeds): - for offset, word in fixes: - emb = self.embeddings.word_embeddings[word] - emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) - tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len] + if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0: + return inputs_embeds + + vecs = [] + for fixes, tensor in zip(batch_fixes, inputs_embeds): + for offset, embedding in fixes: + emb = embedding.vec + emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) + tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]]) + + vecs.append(tensor) - return inputs_embeds + return torch.stack(vecs) def add_circular_option_to_conv_2d(): diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py new file mode 100644 index 00000000..9c079e57 --- /dev/null +++ b/modules/sd_hijack_optimizations.py @@ -0,0 +1,164 @@ +import math +import torch +from torch import einsum + +from ldm.util import default +from einops import rearrange + + +# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion +def split_cross_attention_forward_v1(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) + for i in range(0, q.shape[0], 2): + end = i + 2 + s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) + s1 *= self.scale + + s2 = s1.softmax(dim=-1) + del s1 + + r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) + del s2 + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) + + +# taken from https://github.com/Doggettx/stable-diffusion +def split_cross_attention_forward(self, x, context=None, mask=None): + h = self.heads + + q_in = self.to_q(x) + context = default(context, x) + k_in = self.to_k(context) * self.scale + v_in = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + if mem_required > mem_free_total: + steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' + f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) + + s2 = s1.softmax(dim=-1, dtype=q.dtype) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) + +def nonlinearity_hijack(x): + # swish + t = torch.sigmoid(x) + x *= t + del t + + return x + +def cross_attention_attnblock_forward(self, x): + h_ = x + h_ = self.norm(h_) + q1 = self.q(h_) + k1 = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q1.shape + + q2 = q1.reshape(b, c, h*w) + del q1 + + q = q2.permute(0, 2, 1) # b,hw,c + del q2 + + k = k1.reshape(b, c, h*w) # b,c,hw + del k1 + + h_ = torch.zeros_like(k, device=q.device) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + + w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w2 = w1 * (int(c)**(-0.5)) + del w1 + w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) + del w2 + + # attend to values + v1 = v.reshape(b, c, h*w) + w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w3 + + h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w4 + + h2 = h_.reshape(b, c, h, w) + del h_ + + h3 = self.proj_out(h2) + del h2 + + h3 += x + + return h3 diff --git a/modules/sd_models.py b/modules/sd_models.py index 2539f14c..5b3dbdc7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -8,7 +8,7 @@ from omegaconf import OmegaConf from ldm.util import instantiate_from_config -from modules import shared, modelloader +from modules import shared, modelloader, devices from modules.paths import models_path model_dir = "Stable-diffusion" @@ -134,6 +134,8 @@ def load_model_weights(model, checkpoint_file, sd_model_hash): if not shared.cmd_opts.no_half: model.half() + devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + model.sd_model_hash = sd_model_hash model.sd_model_checkpint = checkpoint_file diff --git a/modules/shared.py b/modules/shared.py index ac968b2d..ac0bc480 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -78,6 +78,7 @@ class State: current_latent = None current_image = None current_image_sampling_step = 0 + textinfo = None def interrupt(self): self.interrupted = True @@ -88,7 +89,7 @@ class State: self.current_image_sampling_step = 0 def get_job_timestamp(self): - return datetime.datetime.now().strftime("%Y%m%d%H%M%S") + return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp? state = State() diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py new file mode 100644 index 00000000..7e134a08 --- /dev/null +++ b/modules/textual_inversion/dataset.py @@ -0,0 +1,76 @@ +import os +import numpy as np +import PIL +import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + +import random +import tqdm + + +class PersonalizedBase(Dataset): + def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None): + + self.placeholder_token = placeholder_token + + self.size = size + self.width = width + self.height = height + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + self.dataset = [] + + with open(template_file, "r") as file: + lines = [x.strip() for x in file.readlines()] + + self.lines = lines + + assert data_root, 'dataset directory not specified' + + self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] + print("Preparing dataset...") + for path in tqdm.tqdm(self.image_paths): + image = Image.open(path) + image = image.convert('RGB') + image = image.resize((self.width, self.height), PIL.Image.BICUBIC) + + filename = os.path.basename(path) + filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-') + filename_tokens = [token for token in filename_tokens if token.isalpha()] + + npimage = np.array(image).astype(np.uint8) + npimage = (npimage / 127.5 - 1.0).astype(np.float32) + + torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32) + torchdata = torch.moveaxis(torchdata, 2, 0) + + init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() + + self.dataset.append((init_latent, filename_tokens)) + + self.length = len(self.dataset) * repeats + + self.initial_indexes = np.arange(self.length) % len(self.dataset) + self.indexes = None + self.shuffle() + + def shuffle(self): + self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] + + def __len__(self): + return self.length + + def __getitem__(self, i): + if i % len(self.dataset) == 0: + self.shuffle() + + index = self.indexes[i % len(self.indexes)] + x, filename_tokens = self.dataset[index] + + text = random.choice(self.lines) + text = text.replace("[name]", self.placeholder_token) + text = text.replace("[filewords]", ' '.join(filename_tokens)) + + return x, text diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py new file mode 100644 index 00000000..c0baaace --- /dev/null +++ b/modules/textual_inversion/textual_inversion.py @@ -0,0 +1,258 @@ +import os +import sys +import traceback + +import torch +import tqdm +import html +import datetime + +from modules import shared, devices, sd_hijack, processing +import modules.textual_inversion.dataset + + +class Embedding: + def __init__(self, vec, name, step=None): + self.vec = vec + self.name = name + self.step = step + self.cached_checksum = None + + def save(self, filename): + embedding_data = { + "string_to_token": {"*": 265}, + "string_to_param": {"*": self.vec}, + "name": self.name, + "step": self.step, + } + + torch.save(embedding_data, filename) + + def checksum(self): + if self.cached_checksum is not None: + return self.cached_checksum + + def const_hash(a): + r = 0 + for v in a: + r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF + return r + + self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}' + return self.cached_checksum + +class EmbeddingDatabase: + def __init__(self, embeddings_dir): + self.ids_lookup = {} + self.word_embeddings = {} + self.dir_mtime = None + self.embeddings_dir = embeddings_dir + + def register_embedding(self, embedding, model): + + self.word_embeddings[embedding.name] = embedding + + ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0] + + first_id = ids[0] + if first_id not in self.ids_lookup: + self.ids_lookup[first_id] = [] + self.ids_lookup[first_id].append((ids, embedding)) + + return embedding + + def load_textual_inversion_embeddings(self): + mt = os.path.getmtime(self.embeddings_dir) + if self.dir_mtime is not None and mt <= self.dir_mtime: + return + + self.dir_mtime = mt + self.ids_lookup.clear() + self.word_embeddings.clear() + + def process_file(path, filename): + name = os.path.splitext(filename)[0] + + data = torch.load(path, map_location="cpu") + + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + if hasattr(param_dict, '_parameters'): + param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + # diffuser concepts + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + else: + raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + + vec = emb.detach().to(devices.device, dtype=torch.float32) + embedding = Embedding(vec, name) + embedding.step = data.get('step', None) + self.register_embedding(embedding, shared.sd_model) + + for fn in os.listdir(self.embeddings_dir): + try: + fullfn = os.path.join(self.embeddings_dir, fn) + + if os.stat(fullfn).st_size == 0: + continue + + process_file(fullfn, fn) + except Exception: + print(f"Error loading emedding {fn}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + continue + + print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") + + def find_embedding_at_position(self, tokens, offset): + token = tokens[offset] + possible_matches = self.ids_lookup.get(token, None) + + if possible_matches is None: + return None + + for ids, embedding in possible_matches: + if tokens[offset:offset + len(ids)] == ids: + return embedding + + return None + + + +def create_embedding(name, num_vectors_per_token): + init_text = '*' + + cond_model = shared.sd_model.cond_stage_model + embedding_layer = cond_model.wrapped.transformer.text_model.embeddings + + ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"] + embedded = embedding_layer(ids.to(devices.device)).squeeze(0) + vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device) + + for i in range(num_vectors_per_token): + vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token] + + fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt") + assert not os.path.exists(fn), f"file {fn} already exists" + + embedding = Embedding(vec, name) + embedding.step = 0 + embedding.save(fn) + + return fn + + +def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file): + assert embedding_name, 'embedding not selected' + + shared.state.textinfo = "Initializing textual inversion training..." + shared.state.job_count = steps + + filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') + + log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name) + + if save_embedding_every > 0: + embedding_dir = os.path.join(log_directory, "embeddings") + os.makedirs(embedding_dir, exist_ok=True) + else: + embedding_dir = None + + if create_image_every > 0: + images_dir = os.path.join(log_directory, "images") + os.makedirs(images_dir, exist_ok=True) + else: + images_dir = None + + cond_model = shared.sd_model.cond_stage_model + + shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." + with torch.autocast("cuda"): + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) + + hijack = sd_hijack.model_hijack + + embedding = hijack.embedding_db.word_embeddings[embedding_name] + embedding.vec.requires_grad = True + + optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate) + + losses = torch.zeros((32,)) + + last_saved_file = "" + last_saved_image = "" + + ititial_step = embedding.step or 0 + if ititial_step > steps: + return embedding, filename + + pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + for i, (x, text) in pbar: + embedding.step = i + ititial_step + + if embedding.step > steps: + break + + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + c = cond_model([text]) + loss = shared.sd_model(x.unsqueeze(0), c)[0] + + losses[embedding.step % losses.shape[0]] = loss.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + pbar.set_description(f"loss: {losses.mean():.7f}") + + if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0: + last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') + embedding.save(last_saved_file) + + if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: + last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + prompt=text, + steps=20, + do_not_save_grid=True, + do_not_save_samples=True, + ) + + processed = processing.process_images(p) + image = processed.images[0] + + shared.state.current_image = image + image.save(last_saved_image) + + last_saved_image += f", prompt: {text}" + + shared.state.job_no = embedding.step + + shared.state.textinfo = f""" +

+Loss: {losses.mean():.7f}
+Step: {embedding.step}
+Last prompt: {html.escape(text)}
+Last saved embedding: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+

+""" + + embedding.cached_checksum = None + embedding.save(filename) + + return embedding, filename + diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py new file mode 100644 index 00000000..ce3677a9 --- /dev/null +++ b/modules/textual_inversion/ui.py @@ -0,0 +1,32 @@ +import html + +import gradio as gr + +import modules.textual_inversion.textual_inversion as ti +from modules import sd_hijack, shared + + +def create_embedding(name, nvpt): + filename = ti.create_embedding(name, nvpt) + + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() + + return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", "" + + +def train_embedding(*args): + + try: + sd_hijack.undo_optimizations() + + embedding, filename = ti.train_embedding(*args) + + res = f""" +Training {'interrupted' if shared.state.interrupted else 'finished'} after {embedding.step} steps. +Embedding saved to {html.escape(filename)} +""" + return res, "" + except Exception: + raise + finally: + sd_hijack.apply_optimizations() diff --git a/modules/ui.py b/modules/ui.py index 15572bb0..57aef6ff 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -21,6 +21,7 @@ import gradio as gr import gradio.utils import gradio.routes +from modules import sd_hijack from modules.paths import script_path from modules.shared import opts, cmd_opts import modules.shared as shared @@ -32,6 +33,7 @@ import modules.gfpgan_model import modules.codeformer_model import modules.styles import modules.generation_parameters_copypaste +import modules.textual_inversion.ui # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI mimetypes.init() @@ -142,8 +144,8 @@ def save_files(js_data, images, index): return '', '', plaintext_to_html(f"Saved: {filenames[0]}") -def wrap_gradio_call(func): - def f(*args, **kwargs): +def wrap_gradio_call(func, extra_outputs=None): + def f(*args, extra_outputs_array=extra_outputs, **kwargs): run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled if run_memmon: shared.mem_mon.monitor() @@ -159,7 +161,10 @@ def wrap_gradio_call(func): shared.state.job = "" shared.state.job_count = 0 - res = [None, '', f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] + if extra_outputs_array is None: + extra_outputs_array = [None, ''] + + res = extra_outputs_array + [f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] elapsed = time.perf_counter() - t @@ -179,6 +184,7 @@ def wrap_gradio_call(func): res[-1] += f"

Time taken: {elapsed:.2f}s

{vram_html}
" shared.state.interrupted = False + shared.state.job_count = 0 return tuple(res) @@ -187,7 +193,7 @@ def wrap_gradio_call(func): def check_progress_call(id_part): if shared.state.job_count == 0: - return "", gr_show(False), gr_show(False) + return "", gr_show(False), gr_show(False), gr_show(False) progress = 0 @@ -219,13 +225,19 @@ def check_progress_call(id_part): else: preview_visibility = gr_show(True) - return f"

{progressbar}

", preview_visibility, image + if shared.state.textinfo is not None: + textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) + else: + textinfo_result = gr_show(False) + + return f"

{progressbar}

", preview_visibility, image, textinfo_result def check_progress_call_initial(id_part): shared.state.job_count = -1 shared.state.current_latent = None shared.state.current_image = None + shared.state.textinfo = None return check_progress_call(id_part) @@ -399,13 +411,16 @@ def create_toprow(is_img2img): return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste -def setup_progressbar(progressbar, preview, id_part): +def setup_progressbar(progressbar, preview, id_part, textinfo=None): + if textinfo is None: + textinfo = gr.HTML(visible=False) + check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) check_progress.click( fn=lambda: check_progress_call(id_part), show_progress=False, inputs=[], - outputs=[progressbar, preview, preview], + outputs=[progressbar, preview, preview, textinfo], ) check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) @@ -413,11 +428,14 @@ def setup_progressbar(progressbar, preview, id_part): fn=lambda: check_progress_call_initial(id_part), show_progress=False, inputs=[], - outputs=[progressbar, preview, preview], + outputs=[progressbar, preview, preview, textinfo], ) -def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): +def create_ui(wrap_gradio_gpu_call): + import modules.img2img + import modules.txt2img + with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) @@ -483,7 +501,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) txt2img_args = dict( - fn=txt2img, + fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), _js="submit", inputs=[ txt2img_prompt, @@ -675,7 +693,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ) img2img_args = dict( - fn=img2img, + fn=wrap_gradio_gpu_call(modules.img2img.img2img), _js="submit_img2img", inputs=[ dummy_component, @@ -828,7 +846,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): open_extras_folder = gr.Button('Open output directory', elem_id=button_id) submit.click( - fn=run_extras, + fn=wrap_gradio_gpu_call(modules.extras.run_extras), _js="get_extras_tab_index", inputs=[ dummy_component, @@ -878,7 +896,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): pnginfo_send_to_img2img = gr.Button('Send to img2img') image.change( - fn=wrap_gradio_call(run_pnginfo), + fn=wrap_gradio_call(modules.extras.run_pnginfo), inputs=[image], outputs=[html, generation_info, html2], ) @@ -887,7 +905,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") - + with gr.Row(): primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name") secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name") @@ -896,10 +914,96 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method") save_as_half = gr.Checkbox(value=False, label="Safe as float16") modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') - + with gr.Column(variant='panel'): submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() + + with gr.Blocks() as textual_inversion_interface: + with gr.Row().style(equal_height=False): + with gr.Column(): + with gr.Group(): + gr.HTML(value="

Create a new embedding

") + + new_embedding_name = gr.Textbox(label="Name") + nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1) + + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_embedding = gr.Button(value="Create", variant='primary') + + with gr.Group(): + gr.HTML(value="

Train an embedding; must specify a directory with a set of 512x512 images

") + train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + learn_rate = gr.Number(label='Learning rate', value=5.0e-03) + dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") + log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") + template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) + steps = gr.Number(label='Max steps', value=100000, precision=0) + create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0) + save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0) + + with gr.Row(): + with gr.Column(scale=2): + gr.HTML(value="") + + with gr.Column(): + with gr.Row(): + interrupt_training = gr.Button(value="Interrupt") + train_embedding = gr.Button(value="Train", variant='primary') + + with gr.Column(): + progressbar = gr.HTML(elem_id="ti_progressbar") + ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) + + ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) + ti_preview = gr.Image(elem_id='ti_preview', visible=False) + ti_progress = gr.HTML(elem_id="ti_progress", value="") + ti_outcome = gr.HTML(elem_id="ti_error", value="") + setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) + + create_embedding.click( + fn=modules.textual_inversion.ui.create_embedding, + inputs=[ + new_embedding_name, + nvpt, + ], + outputs=[ + train_embedding_name, + ti_output, + ti_outcome, + ] + ) + + train_embedding.click( + fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), + _js="start_training_textual_inversion", + inputs=[ + train_embedding_name, + learn_rate, + dataset_directory, + log_directory, + steps, + create_image_every, + save_embedding_every, + template_file, + ], + outputs=[ + ti_output, + ti_outcome, + ] + ) + + interrupt_training.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) + def create_setting_component(key): def fun(): return opts.data[key] if key in opts.data else opts.data_labels[key].default @@ -1011,6 +1115,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), + (textual_inversion_interface, "Textual inversion", "ti"), (settings_interface, "Settings", "settings"), ] @@ -1044,11 +1149,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): def modelmerger(*args): try: - results = run_modelmerger(*args) + results = modules.extras.run_modelmerger(*args) except Exception as e: print("Error loading/saving model file:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - modules.sd_models.list_models() #To remove the potentially missing models from the list + modules.sd_models.list_models() # to remove the potentially missing models from the list return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)] return results diff --git a/style.css b/style.css index 79d6bb0d..39586bf1 100644 --- a/style.css +++ b/style.css @@ -157,7 +157,7 @@ button{ max-width: 10em; } -#txt2img_preview, #img2img_preview{ +#txt2img_preview, #img2img_preview, #ti_preview{ position: absolute; width: 320px; left: 0; @@ -172,18 +172,18 @@ button{ } @media screen and (min-width: 768px) { - #txt2img_preview, #img2img_preview { + #txt2img_preview, #img2img_preview, #ti_preview { position: absolute; } } @media screen and (max-width: 767px) { - #txt2img_preview, #img2img_preview { + #txt2img_preview, #img2img_preview, #ti_preview { position: relative; } } -#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0{ +#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0, #ti_preview div.left-0.top-0{ display: none; } @@ -247,7 +247,7 @@ input[type="range"]{ #txt2img_negative_prompt, #img2img_negative_prompt{ } -#txt2img_progressbar, #img2img_progressbar{ +#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{ position: absolute; z-index: 1000; right: 0; diff --git a/textual_inversion_templates/style.txt b/textual_inversion_templates/style.txt new file mode 100644 index 00000000..15af2d6b --- /dev/null +++ b/textual_inversion_templates/style.txt @@ -0,0 +1,19 @@ +a painting, art by [name] +a rendering, art by [name] +a cropped painting, art by [name] +the painting, art by [name] +a clean painting, art by [name] +a dirty painting, art by [name] +a dark painting, art by [name] +a picture, art by [name] +a cool painting, art by [name] +a close-up painting, art by [name] +a bright painting, art by [name] +a cropped painting, art by [name] +a good painting, art by [name] +a close-up painting, art by [name] +a rendition, art by [name] +a nice painting, art by [name] +a small painting, art by [name] +a weird painting, art by [name] +a large painting, art by [name] diff --git a/textual_inversion_templates/style_filewords.txt b/textual_inversion_templates/style_filewords.txt new file mode 100644 index 00000000..b3a8159a --- /dev/null +++ b/textual_inversion_templates/style_filewords.txt @@ -0,0 +1,19 @@ +a painting of [filewords], art by [name] +a rendering of [filewords], art by [name] +a cropped painting of [filewords], art by [name] +the painting of [filewords], art by [name] +a clean painting of [filewords], art by [name] +a dirty painting of [filewords], art by [name] +a dark painting of [filewords], art by [name] +a picture of [filewords], art by [name] +a cool painting of [filewords], art by [name] +a close-up painting of [filewords], art by [name] +a bright painting of [filewords], art by [name] +a cropped painting of [filewords], art by [name] +a good painting of [filewords], art by [name] +a close-up painting of [filewords], art by [name] +a rendition of [filewords], art by [name] +a nice painting of [filewords], art by [name] +a small painting of [filewords], art by [name] +a weird painting of [filewords], art by [name] +a large painting of [filewords], art by [name] diff --git a/textual_inversion_templates/subject.txt b/textual_inversion_templates/subject.txt new file mode 100644 index 00000000..79f36aa0 --- /dev/null +++ b/textual_inversion_templates/subject.txt @@ -0,0 +1,27 @@ +a photo of a [name] +a rendering of a [name] +a cropped photo of the [name] +the photo of a [name] +a photo of a clean [name] +a photo of a dirty [name] +a dark photo of the [name] +a photo of my [name] +a photo of the cool [name] +a close-up photo of a [name] +a bright photo of the [name] +a cropped photo of a [name] +a photo of the [name] +a good photo of the [name] +a photo of one [name] +a close-up photo of the [name] +a rendition of the [name] +a photo of the clean [name] +a rendition of a [name] +a photo of a nice [name] +a good photo of a [name] +a photo of the nice [name] +a photo of the small [name] +a photo of the weird [name] +a photo of the large [name] +a photo of a cool [name] +a photo of a small [name] diff --git a/textual_inversion_templates/subject_filewords.txt b/textual_inversion_templates/subject_filewords.txt new file mode 100644 index 00000000..008652a6 --- /dev/null +++ b/textual_inversion_templates/subject_filewords.txt @@ -0,0 +1,27 @@ +a photo of a [name], [filewords] +a rendering of a [name], [filewords] +a cropped photo of the [name], [filewords] +the photo of a [name], [filewords] +a photo of a clean [name], [filewords] +a photo of a dirty [name], [filewords] +a dark photo of the [name], [filewords] +a photo of my [name], [filewords] +a photo of the cool [name], [filewords] +a close-up photo of a [name], [filewords] +a bright photo of the [name], [filewords] +a cropped photo of a [name], [filewords] +a photo of the [name], [filewords] +a good photo of the [name], [filewords] +a photo of one [name], [filewords] +a close-up photo of the [name], [filewords] +a rendition of the [name], [filewords] +a photo of the clean [name], [filewords] +a rendition of a [name], [filewords] +a photo of a nice [name], [filewords] +a good photo of a [name], [filewords] +a photo of the nice [name], [filewords] +a photo of the small [name], [filewords] +a photo of the weird [name], [filewords] +a photo of the large [name], [filewords] +a photo of a cool [name], [filewords] +a photo of a small [name], [filewords] diff --git a/webui.py b/webui.py index b8cccd54..19fdcdd4 100644 --- a/webui.py +++ b/webui.py @@ -12,7 +12,6 @@ import modules.bsrgan_model as bsrgan import modules.extras import modules.face_restoration import modules.gfpgan_model as gfpgan -import modules.img2img import modules.ldsr_model as ldsr import modules.lowvram import modules.realesrgan_model as realesrgan @@ -21,7 +20,6 @@ import modules.sd_hijack import modules.sd_models import modules.shared as shared import modules.swinir_model as swinir -import modules.txt2img import modules.ui from modules import modelloader from modules.paths import script_path @@ -46,7 +44,7 @@ def wrap_queued_call(func): return f -def wrap_gradio_gpu_call(func): +def wrap_gradio_gpu_call(func, extra_outputs=None): def f(*args, **kwargs): devices.torch_gc() @@ -58,6 +56,7 @@ def wrap_gradio_gpu_call(func): shared.state.current_image = None shared.state.current_image_sampling_step = 0 shared.state.interrupted = False + shared.state.textinfo = None with queue_lock: res = func(*args, **kwargs) @@ -69,7 +68,7 @@ def wrap_gradio_gpu_call(func): return res - return modules.ui.wrap_gradio_call(f) + return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) modules.scripts.load_scripts(os.path.join(script_path, "scripts")) @@ -86,13 +85,7 @@ def webui(): signal.signal(signal.SIGINT, sigint_handler) - demo = modules.ui.create_ui( - txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img), - img2img=wrap_gradio_gpu_call(modules.img2img.img2img), - run_extras=wrap_gradio_gpu_call(modules.extras.run_extras), - run_pnginfo=modules.extras.run_pnginfo, - run_modelmerger=modules.extras.run_modelmerger - ) + demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) demo.launch( share=cmd_opts.share, -- cgit v1.2.3 From 3ff0de2c594b786ef948a89efb1814c59bb42117 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 20:23:40 +0300 Subject: added --disable-console-progressbars to disable progressbars in console disabled printing prompts to console by default, enabled by --enable-console-prompts --- modules/img2img.py | 4 +++- modules/sd_samplers.py | 8 ++++++-- modules/shared.py | 7 +++++-- modules/txt2img.py | 4 +++- 4 files changed, 17 insertions(+), 6 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/img2img.py b/modules/img2img.py index 03e934e9..f4455c90 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -103,7 +103,9 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro inpaint_full_res_padding=inpaint_full_res_padding, inpainting_mask_invert=inpainting_mask_invert, ) - print(f"\nimg2img: {prompt}", file=shared.progress_print_out) + + if shared.cmd_opts.enable_console_prompts: + print(f"\nimg2img: {prompt}", file=shared.progress_print_out) p.extra_generation_params["Mask blur"] = mask_blur diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 92522214..9316875a 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -77,7 +77,9 @@ def extended_tdqm(sequence, *args, desc=None, **kwargs): state.sampling_steps = len(sequence) state.sampling_step = 0 - for x in tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs): + seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs) + + for x in seq: if state.interrupted: break @@ -207,7 +209,9 @@ def extended_trange(sampler, count, *args, **kwargs): state.sampling_steps = count state.sampling_step = 0 - for x in tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs): + seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs) + + for x in seq: if state.interrupted: break diff --git a/modules/shared.py b/modules/shared.py index 5a591dc9..1bf7a6c1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -58,6 +58,9 @@ parser.add_argument("--opt-channelslast", action='store_true', help="change memo parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) +parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) +parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) + cmd_opts = parser.parse_args() device = get_optimal_device() @@ -320,14 +323,14 @@ class TotalTQDM: ) def update(self): - if not opts.multiple_tqdm: + if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars: return if self._tqdm is None: self.reset() self._tqdm.update() def updateTotal(self, new_total): - if not opts.multiple_tqdm: + if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars: return if self._tqdm is None: self.reset() diff --git a/modules/txt2img.py b/modules/txt2img.py index 5368e4d0..d4406c3c 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -34,7 +34,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: denoising_strength=denoising_strength if enable_hr else None, ) - print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) + if cmd_opts.enable_console_prompts: + print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) + processed = modules.scripts.scripts_txt2img.run(p, *args) if processed is None: -- cgit v1.2.3 From 91f327f22bb2feb780c424c74723cc0629dc34a1 Mon Sep 17 00:00:00 2001 From: Lopyter Date: Sun, 2 Oct 2022 18:15:31 +0200 Subject: make save to dirs optional for imgs saved from ui --- modules/shared.py | 1 + modules/ui.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 1bf7a6c1..785e7af6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -173,6 +173,7 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo "grid_save_to_dirs": OptionInfo(False, "Save grids to subdirectory"), "directories_filename_pattern": OptionInfo("", "Directory name pattern"), "directories_max_prompt_words": OptionInfo(8, "Max prompt words", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}), + "use_save_to_dirs_for_ui": OptionInfo(False, "Use \"Save images to a subdirectory\" option for images saved from UI"), })) options_templates.update(options_section(('upscaling', "Upscaling"), { diff --git a/modules/ui.py b/modules/ui.py index 78a15d83..8912deff 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -113,7 +113,7 @@ def save_files(js_data, images, index): p = MyObject(data) path = opts.outdir_save - save_to_dirs = opts.save_to_dirs + save_to_dirs = opts.use_save_to_dirs_for_ui if save_to_dirs: dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, p.seed, p.prompt) -- cgit v1.2.3 From c4445225f79f1c57afe52358ff4b205864eb7aac Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 2 Oct 2022 21:50:14 +0300 Subject: change wording for options --- modules/shared.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 785e7af6..7246eadc 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -170,10 +170,10 @@ options_templates.update(options_section(('saving-paths', "Paths for saving"), { options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), { "save_to_dirs": OptionInfo(False, "Save images to a subdirectory"), - "grid_save_to_dirs": OptionInfo(False, "Save grids to subdirectory"), + "grid_save_to_dirs": OptionInfo(False, "Save grids to a subdirectory"), + "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"), "directories_filename_pattern": OptionInfo("", "Directory name pattern"), - "directories_max_prompt_words": OptionInfo(8, "Max prompt words", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}), - "use_save_to_dirs_for_ui": OptionInfo(False, "Use \"Save images to a subdirectory\" option for images saved from UI"), + "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}), })) options_templates.update(options_section(('upscaling', "Upscaling"), { -- cgit v1.2.3 From 138662734c25dab4e73e632b7eaff9ad9c0ce2b4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 3 Oct 2022 07:57:59 +0300 Subject: use dropdown instead of radio for img2img upscaler selection --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 7246eadc..2a599e9c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -183,7 +183,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}), "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}), - "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}), + "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), })) options_templates.update(options_section(('face-restoration', "Face restoration"), { -- cgit v1.2.3 From eeab7aedf532680a6ae9058ee272450bb07e41eb Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 4 Oct 2022 04:24:35 -0400 Subject: Add --use-cpu command line option Remove MPS detection to use CPU for GFPGAN / CodeFormer and add a --use-cpu command line option. --- modules/devices.py | 5 ++--- modules/esrgan_model.py | 9 ++++----- modules/scunet_model.py | 8 ++++---- modules/shared.py | 9 +++++++-- 4 files changed, 17 insertions(+), 14 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/devices.py b/modules/devices.py index 5d9c7a07..b5a0cd29 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,8 +1,8 @@ import torch -# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility from modules import errors +# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility has_mps = getattr(torch, 'has_mps', False) cpu = torch.device("cpu") @@ -32,8 +32,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") -device = get_optimal_device() -device_gfpgan = device_codeformer = cpu if device.type == 'mps' else device +device = device_gfpgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() dtype = torch.float16 def randn(seed, shape): diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 4aed9283..d17e730f 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -6,8 +6,7 @@ from PIL import Image from basicsr.utils.download_util import load_file_from_url import modules.esrgam_model_arch as arch -from modules import shared, modelloader, images -from modules.devices import has_mps +from modules import shared, modelloader, images, devices from modules.paths import models_path from modules.upscaler import Upscaler, UpscalerData from modules.shared import opts @@ -97,7 +96,7 @@ class UpscalerESRGAN(Upscaler): model = self.load_model(selected_model) if model is None: return img - model.to(shared.device) + model.to(devices.device_esrgan) img = esrgan_upscale(model, img) return img @@ -112,7 +111,7 @@ class UpscalerESRGAN(Upscaler): print("Unable to load %s from %s" % (self.model_path, filename)) return None - pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None) + pretrained_net = torch.load(filename, map_location='cpu' if shared.device.type == 'mps' else None) crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32) pretrained_net = fix_model_layers(crt_model, pretrained_net) @@ -127,7 +126,7 @@ def upscale_without_tiling(model, img): img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(shared.device) + img = img.unsqueeze(0).to(devices.device_esrgan) with torch.no_grad(): output = model(img) output = output.squeeze().float().cpu().clamp_(0, 1).numpy() diff --git a/modules/scunet_model.py b/modules/scunet_model.py index 7987ac14..fb64b740 100644 --- a/modules/scunet_model.py +++ b/modules/scunet_model.py @@ -8,7 +8,7 @@ import torch from basicsr.utils.download_util import load_file_from_url import modules.upscaler -from modules import shared, modelloader +from modules import devices, modelloader from modules.paths import models_path from modules.scunet_model_arch import SCUNet as net @@ -51,12 +51,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler): if model is None: return img - device = shared.device + device = devices.device_scunet img = np.array(img) img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(shared.device) + img = img.unsqueeze(0).to(device) img = img.to(device) with torch.no_grad(): @@ -69,7 +69,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): return PIL.Image.fromarray(output, 'RGB') def load_model(self, path: str): - device = shared.device + device = devices.device_scunet if "http" in path: filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, progress=True) diff --git a/modules/shared.py b/modules/shared.py index 2a599e9c..7899ab8d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -12,7 +12,7 @@ import modules.interrogate import modules.memmon import modules.sd_models import modules.styles -from modules.devices import get_optimal_device +import modules.devices as devices from modules.paths import script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') @@ -46,6 +46,7 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") +parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU for specified modules", default=[]) parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) @@ -63,7 +64,11 @@ parser.add_argument("--enable-console-prompts", action='store_true', help="print cmd_opts = parser.parse_args() -device = get_optimal_device() + +devices.device, devices.device_gfpgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ +(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'ESRGAN', 'SCUNet', 'CodeFormer']) + +device = devices.device batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram -- cgit v1.2.3 From 27ddc24fdee1fbe709054a43235ab7f9c51b3e9f Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 4 Oct 2022 05:18:17 -0400 Subject: Add BSRGAN to --add-cpu --- modules/bsrgan_model.py | 6 +++--- modules/devices.py | 2 +- modules/shared.py | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/bsrgan_model.py b/modules/bsrgan_model.py index e62c6657..3bd80791 100644 --- a/modules/bsrgan_model.py +++ b/modules/bsrgan_model.py @@ -8,7 +8,7 @@ import torch from basicsr.utils.download_util import load_file_from_url import modules.upscaler -from modules import shared, modelloader +from modules import devices, modelloader from modules.bsrgan_model_arch import RRDBNet from modules.paths import models_path @@ -44,13 +44,13 @@ class UpscalerBSRGAN(modules.upscaler.Upscaler): model = self.load_model(selected_file) if model is None: return img - model.to(shared.device) + model.to(devices.device_bsrgan) torch.cuda.empty_cache() img = np.array(img) img = img[:, :, ::-1] img = np.moveaxis(img, 2, 0) / 255 img = torch.from_numpy(img).float() - img = img.unsqueeze(0).to(shared.device) + img = img.unsqueeze(0).to(devices.device_bsrgan) with torch.no_grad(): output = model(img) output = output.squeeze().float().cpu().clamp_(0, 1).numpy() diff --git a/modules/devices.py b/modules/devices.py index b5a0cd29..b7899632 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -32,7 +32,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") -device = device_gfpgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() +device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() dtype = torch.float16 def randn(seed, shape): diff --git a/modules/shared.py b/modules/shared.py index 7899ab8d..95b98a06 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -46,7 +46,7 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") -parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU for specified modules", default=[]) +parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU for specified modules", default=[]) parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) @@ -65,8 +65,8 @@ parser.add_argument("--enable-console-prompts", action='store_true', help="print cmd_opts = parser.parse_args() -devices.device, devices.device_gfpgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ -(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'ESRGAN', 'SCUNet', 'CodeFormer']) +devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ +(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer']) device = devices.device -- cgit v1.2.3 From dc9c5a97742e3a34d37da7108642d8adc0dc5858 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 4 Oct 2022 05:22:50 -0400 Subject: Modify --add-cpu description --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 95b98a06..25aff5b0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -46,7 +46,7 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") -parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU for specified modules", default=[]) +parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[]) parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) -- cgit v1.2.3 From accd00d6b8258c12b5168918a4c546b02357924a Mon Sep 17 00:00:00 2001 From: Justin Riddiough Date: Tue, 4 Oct 2022 01:14:28 -0500 Subject: Explain how to use second progress bar in pycharm --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 25aff5b0..11bdf01a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -200,7 +200,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration" options_templates.update(options_section(('system', "System"), { "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), - "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."), + "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. In PyCharm select 'emulate terminal in console output'."), })) options_templates.update(options_section(('sd', "Stable Diffusion"), { -- cgit v1.2.3 From ea6b0d98a64290a0305e27126ea59ce1da7959a2 Mon Sep 17 00:00:00 2001 From: Justin Riddiough Date: Tue, 4 Oct 2022 06:38:45 -0500 Subject: Remove pycharm note, fix typo --- modules/shared.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 11bdf01a..a7d13b2d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -200,7 +200,7 @@ options_templates.update(options_section(('face-restoration', "Face restoration" options_templates.update(options_section(('system', "System"), { "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), - "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. In PyCharm select 'emulate terminal in console output'."), + "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), })) options_templates.update(options_section(('sd', "Stable Diffusion"), { @@ -209,7 +209,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), - "enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), + "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), -- cgit v1.2.3 From 957e29a8e9cb8ca069799ec69263e188c89ed6a6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 4 Oct 2022 17:23:48 +0300 Subject: option to not show images in web ui --- modules/img2img.py | 3 +++ modules/shared.py | 1 + modules/txt2img.py | 3 +++ 3 files changed, 7 insertions(+) (limited to 'modules/shared.py') diff --git a/modules/img2img.py b/modules/img2img.py index 2ff8e261..da212d72 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -129,4 +129,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro if opts.samples_log_stdout: print(generation_info_js) + if opts.do_not_show_images: + processed.images = [] + return processed.images, generation_info_js, plaintext_to_html(processed.info) diff --git a/modules/shared.py b/modules/shared.py index a7d13b2d..ff4e5fa3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -229,6 +229,7 @@ options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), "show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), "return_grid": OptionInfo(True, "Show grid in results for web"), + "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), diff --git a/modules/txt2img.py b/modules/txt2img.py index d4406c3c..e985242b 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -48,5 +48,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: if opts.samples_log_stdout: print(generation_info_js) + if opts.do_not_show_images: + processed.images = [] + return processed.images, generation_info_js, plaintext_to_html(processed.info) -- cgit v1.2.3 From b32852ef037251eb3d846af76e2965594e1ac7a5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 4 Oct 2022 20:49:54 +0300 Subject: add editor to img2img --- modules/shared.py | 1 + modules/ui.py | 2 +- style.css | 4 ++++ 3 files changed, 6 insertions(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index ff4e5fa3..e52c9b1d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -55,6 +55,7 @@ parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide dire parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) +parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="color-sketch") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) diff --git a/modules/ui.py b/modules/ui.py index 20dc8c37..6cd6761b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -644,7 +644,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Tabs(elem_id="mode_img2img") as tabs_img2img_mode: with gr.TabItem('img2img', id='img2img'): - init_img = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil") + init_img = gr.Image(label="Image for img2img", elem_id="img2img_image", show_label=False, source="upload", interactive=True, type="pil", tool=cmd_opts.gradio_img2img_tool) with gr.TabItem('Inpaint', id='inpaint'): init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA") diff --git a/style.css b/style.css index 39586bf1..e8f4cb75 100644 --- a/style.css +++ b/style.css @@ -403,3 +403,7 @@ input[type="range"]{ .red { color: red; } + +#img2img_image div.h-60{ + height: 480px; +} \ No newline at end of file -- cgit v1.2.3 From 55400c981b7c1389482057a35ed6ea11f08da194 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 6 Oct 2022 03:11:15 +0100 Subject: Set gradio-img2img-tool default to 'editor' --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index e52c9b1d..bab0fe6e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -55,7 +55,7 @@ parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide dire parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) -parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="color-sketch") +parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image uploader tool: can be either editor for ctopping, or color-sketch for drawing', choices=["color-sketch", "editor"], default="editor") parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) -- cgit v1.2.3 From 5f24b7bcf4a074fbdec757617fcd1bc82e76551b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 6 Oct 2022 12:08:48 +0300 Subject: option to let users select which samplers they want to hide --- modules/processing.py | 13 ++++++------- modules/sd_samplers.py | 19 +++++++++++++++++-- modules/shared.py | 15 +++++++++------ webui.py | 4 +++- 4 files changed, 35 insertions(+), 16 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/processing.py b/modules/processing.py index d8c6b8d5..e01c8b3f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -11,9 +11,8 @@ import cv2 from skimage import exposure import modules.sd_hijack -from modules import devices, prompt_parser, masking +from modules import devices, prompt_parser, masking, sd_samplers from modules.sd_hijack import model_hijack -from modules.sd_samplers import samplers, samplers_for_img2img from modules.shared import opts, cmd_opts, state import modules.shared as shared import modules.face_restoration @@ -110,7 +109,7 @@ class Processed: self.width = p.width self.height = p.height self.sampler_index = p.sampler_index - self.sampler = samplers[p.sampler_index].name + self.sampler = sd_samplers.samplers[p.sampler_index].name self.cfg_scale = p.cfg_scale self.steps = p.steps self.batch_size = p.batch_size @@ -265,7 +264,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params = { "Steps": p.steps, - "Sampler": samplers[p.sampler_index].name, + "Sampler": sd_samplers.samplers[p.sampler_index].name, "CFG scale": p.cfg_scale, "Seed": all_seeds[index], "Face restoration": (opts.face_restoration_model if p.restore_faces else None), @@ -478,7 +477,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.firstphase_height_truncated = int(scale * self.height) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): - self.sampler = samplers[self.sampler_index].constructor(self.sd_model) + self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model) if not self.enable_hr: x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) @@ -521,7 +520,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob() - self.sampler = samplers[self.sampler_index].constructor(self.sd_model) + self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model) noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) # GC now before running the next img2img to prevent running out of memory @@ -556,7 +555,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.nmask = None def init(self, all_prompts, all_seeds, all_subseeds): - self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model) + self.sampler = sd_samplers.samplers_for_img2img[self.sampler_index].constructor(self.sd_model) crop_region = None if self.image_mask is not None: diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index d27c547b..2e1f7715 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -32,12 +32,27 @@ samplers_data_k_diffusion = [ if hasattr(k_diffusion.sampling, funcname) ] -samplers = [ +all_samplers = [ *samplers_data_k_diffusion, SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []), SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []), ] -samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']] + +samplers = [] +samplers_for_img2img = [] + + +def set_samplers(): + global samplers, samplers_for_img2img + + hidden = set(opts.hide_samplers) + hidden_img2img = set(opts.hide_samplers + ['PLMS', 'DPM fast', 'DPM adaptive']) + + samplers = [x for x in all_samplers if x.name not in hidden] + samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] + + +set_samplers() sampler_extra_params = { 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], diff --git a/modules/shared.py b/modules/shared.py index bab0fe6e..ca2e4c74 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,6 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices +from modules import sd_samplers from modules.paths import script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') @@ -238,14 +239,16 @@ options_templates.update(options_section(('ui', "User interface"), { })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { - "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), - 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "hide_samplers": OptionInfo([], "Hide samplers in user interface (requires restart)", gr.CheckboxGroup, lambda: {"choices": [x.name for x in sd_samplers.all_samplers]}), + "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), + 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), })) + class Options: data = None data_labels = options_templates diff --git a/webui.py b/webui.py index 47848ba5..9ef12427 100644 --- a/webui.py +++ b/webui.py @@ -2,7 +2,7 @@ import os import threading import time import importlib -from modules import devices +from modules import devices, sd_samplers from modules.paths import script_path import signal import threading @@ -109,6 +109,8 @@ def webui(): time.sleep(0.5) break + sd_samplers.set_samplers() + print('Reloading Custom Scripts') modules.scripts.reload_scripts(os.path.join(script_path, "scripts")) print('Reloading modules: modules.ui') -- cgit v1.2.3 From 3ddf80a9db8793188e2fe9488233d2b272cceb33 Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Wed, 5 Oct 2022 14:31:51 +0300 Subject: add variant setting --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index ca2e4c74..9e4860a2 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -236,6 +236,7 @@ options_templates.update(options_section(('ui', "User interface"), { "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), + "show_karras_scheduler_variants": OptionInfo(True, "Show Karras scheduling variants for select samplers. Try these variants if your K sampled images suffer from excessive noise."), })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { -- cgit v1.2.3 From 5993df24a1026225cb8af89237547c1d9101ce69 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 6 Oct 2022 14:12:52 +0300 Subject: integrate the new samplers PR --- modules/processing.py | 7 ++-- modules/sd_samplers.py | 59 +++++++++++++++------------- modules/shared.py | 1 - scripts/alternate_sampler_noise_schedules.py | 53 ------------------------- scripts/img2imgalt.py | 3 +- 5 files changed, 36 insertions(+), 87 deletions(-) delete mode 100644 scripts/alternate_sampler_noise_schedules.py (limited to 'modules/shared.py') diff --git a/modules/processing.py b/modules/processing.py index e01c8b3f..e567956c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -477,7 +477,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.firstphase_height_truncated = int(scale * self.height) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): - self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model) + self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) if not self.enable_hr: x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) @@ -520,7 +520,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob() - self.sampler = sd_samplers.samplers[self.sampler_index].constructor(self.sd_model) + self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) + noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) # GC now before running the next img2img to prevent running out of memory @@ -555,7 +556,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.nmask = None def init(self, all_prompts, all_seeds, all_subseeds): - self.sampler = sd_samplers.samplers_for_img2img[self.sampler_index].constructor(self.sd_model) + self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model) crop_region = None if self.image_mask is not None: diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 8d6eb762..497df943 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -13,46 +13,46 @@ from modules.shared import opts, cmd_opts, state import modules.shared as shared -SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases']) +SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) samplers_k_diffusion = [ - ('Euler a', 'sample_euler_ancestral', ['k_euler_a']), - ('Euler', 'sample_euler', ['k_euler']), - ('LMS', 'sample_lms', ['k_lms']), - ('Heun', 'sample_heun', ['k_heun']), - ('DPM2', 'sample_dpm_2', ['k_dpm_2']), - ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']), - ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast']), - ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']), + ('Euler a', 'sample_euler_ancestral', ['k_euler_a'], {}), + ('Euler', 'sample_euler', ['k_euler'], {}), + ('LMS', 'sample_lms', ['k_lms'], {}), + ('Heun', 'sample_heun', ['k_heun'], {}), + ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}), + ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}), + ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}), + ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}), + ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), + ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}), + ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}), ] -if opts.show_karras_scheduler_variants: - k_diffusion.sampling.sample_dpm_2_ka = k_diffusion.sampling.sample_dpm_2 - k_diffusion.sampling.sample_dpm_2_ancestral_ka = k_diffusion.sampling.sample_dpm_2_ancestral - k_diffusion.sampling.sample_lms_ka = k_diffusion.sampling.sample_lms - samplers_k_diffusion_ka = [ - ('LMS K Scheduling', 'sample_lms_ka', ['k_lms_ka']), - ('DPM2 K Scheduling', 'sample_dpm_2_ka', ['k_dpm_2_ka']), - ('DPM2 a K Scheduling', 'sample_dpm_2_ancestral_ka', ['k_dpm_2_a_ka']), - ] - samplers_k_diffusion.extend(samplers_k_diffusion_ka) - samplers_data_k_diffusion = [ - SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases) - for label, funcname, aliases in samplers_k_diffusion + SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) + for label, funcname, aliases, options in samplers_k_diffusion if hasattr(k_diffusion.sampling, funcname) ] all_samplers = [ *samplers_data_k_diffusion, - SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []), - SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []), + SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}), + SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}), ] samplers = [] samplers_for_img2img = [] +def create_sampler_with_index(list_of_configs, index, model): + config = list_of_configs[index] + sampler = config.constructor(model) + sampler.config = config + + return sampler + + def set_samplers(): global samplers, samplers_for_img2img @@ -130,6 +130,7 @@ class VanillaStableDiffusionSampler: self.step = 0 self.eta = None self.default_eta = 0.0 + self.config = None def number_of_needed_noises(self, p): return 0 @@ -291,6 +292,7 @@ class KDiffusionSampler: self.stop_at = None self.eta = None self.default_eta = 1.0 + self.config = None def callback_state(self, d): store_latent(d["denoised"]) @@ -355,11 +357,12 @@ class KDiffusionSampler: steps = steps or p.steps if p.sampler_noise_scheduler_override: - sigmas = p.sampler_noise_scheduler_override(steps) - elif self.funcname.endswith('ka'): - sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device) + sigmas = p.sampler_noise_scheduler_override(steps) + elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': + sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device) else: - sigmas = self.model_wrap.get_sigmas(steps) + sigmas = self.model_wrap.get_sigmas(steps) + x = x * sigmas[0] extra_params_kwargs = self.initialize(p) diff --git a/modules/shared.py b/modules/shared.py index 9e4860a2..ca2e4c74 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -236,7 +236,6 @@ options_templates.update(options_section(('ui', "User interface"), { "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), - "show_karras_scheduler_variants": OptionInfo(True, "Show Karras scheduling variants for select samplers. Try these variants if your K sampled images suffer from excessive noise."), })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { diff --git a/scripts/alternate_sampler_noise_schedules.py b/scripts/alternate_sampler_noise_schedules.py deleted file mode 100644 index 4f3ed8fb..00000000 --- a/scripts/alternate_sampler_noise_schedules.py +++ /dev/null @@ -1,53 +0,0 @@ -import inspect -from modules.processing import Processed, process_images -import gradio as gr -import modules.scripts as scripts -import k_diffusion.sampling -import torch - - -class Script(scripts.Script): - - def title(self): - return "Alternate Sampler Noise Schedules" - - def ui(self, is_img2img): - noise_scheduler = gr.Dropdown(label="Noise Scheduler", choices=['Default','Karras','Exponential', 'Variance Preserving'], value='Default', type="index") - sched_smin = gr.Slider(value=0.1, label="Sigma min", minimum=0.0, maximum=100.0, step=0.5,) - sched_smax = gr.Slider(value=10.0, label="Sigma max", minimum=0.0, maximum=100.0, step=0.5) - sched_rho = gr.Slider(value=7.0, label="Sigma rho (Karras only)", minimum=7.0, maximum=100.0, step=0.5) - sched_beta_d = gr.Slider(value=19.9, label="Beta distribution (VP only)",minimum=0.0, maximum=40.0, step=0.5) - sched_beta_min = gr.Slider(value=0.1, label="Beta min (VP only)", minimum=0.0, maximum=40.0, step=0.1) - sched_eps_s = gr.Slider(value=0.001, label="Epsilon (VP only)", minimum=0.001, maximum=1.0, step=0.001) - - return [noise_scheduler, sched_smin, sched_smax, sched_rho, sched_beta_d, sched_beta_min, sched_eps_s] - - def run(self, p, noise_scheduler, sched_smin, sched_smax, sched_rho, sched_beta_d, sched_beta_min, sched_eps_s): - - noise_scheduler_func_name = ['-','get_sigmas_karras','get_sigmas_exponential','get_sigmas_vp'][noise_scheduler] - - base_params = { - "sigma_min":sched_smin, - "sigma_max":sched_smax, - "rho":sched_rho, - "beta_d":sched_beta_d, - "beta_min":sched_beta_min, - "eps_s":sched_eps_s, - "device":"cuda" if torch.cuda.is_available() else "cpu" - } - - if hasattr(k_diffusion.sampling,noise_scheduler_func_name): - - sigma_func = getattr(k_diffusion.sampling,noise_scheduler_func_name) - sigma_func_kwargs = {} - - for k,v in base_params.items(): - if k in inspect.signature(sigma_func).parameters: - sigma_func_kwargs[k] = v - - def substitute_noise_scheduler(n): - return sigma_func(n,**sigma_func_kwargs) - - p.sampler_noise_scheduler_override = substitute_noise_scheduler - - return process_images(p) diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py index 0ef137f7..f9894cb0 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -8,7 +8,6 @@ import gradio as gr from modules import processing, shared, sd_samplers, prompt_parser from modules.processing import Processed -from modules.sd_samplers import samplers from modules.shared import opts, cmd_opts, state import torch @@ -159,7 +158,7 @@ class Script(scripts.Script): combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5) - sampler = samplers[p.sampler_index].constructor(p.sd_model) + sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, p.sampler_index, p.sd_model) sigmas = sampler.model_wrap.get_sigmas(p.steps) -- cgit v1.2.3 From be71115b1a1201d04f0e2a11e718fb31cbd26474 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 6 Oct 2022 01:09:44 +0100 Subject: Update shared.py --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index ca2e4c74..9f7c6efe 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -236,6 +236,7 @@ options_templates.update(options_section(('ui', "User interface"), { "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), + "show_progress_in_title": OptionInfo(False, "Show generation progress in window title."), })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { -- cgit v1.2.3 From fec71e4de24b65b0f205a3c071b71651bbcb0dfc Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Thu, 6 Oct 2022 01:35:07 +0100 Subject: Default window title progress updates on --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 9f7c6efe..5c16f025 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -236,7 +236,7 @@ options_templates.update(options_section(('ui', "User interface"), { "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), - "show_progress_in_title": OptionInfo(False, "Show generation progress in window title."), + "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { -- cgit v1.2.3 From cf7c784fcc0c84a8a4edd8d3aca4dda4c7025c43 Mon Sep 17 00:00:00 2001 From: Milly Date: Fri, 7 Oct 2022 00:19:52 +0900 Subject: Removed duplicate defined models_path Use `modules.paths.models_path` instead `modules.shared.model_path`. --- modules/shared.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 5c16f025..25bb6e6c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,11 +14,10 @@ import modules.sd_models import modules.styles import modules.devices as devices from modules import sd_samplers -from modules.paths import script_path, sd_path +from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file -model_path = os.path.join(script_path, 'models') parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) @@ -36,14 +35,14 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") -parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(model_path, 'Codeformer')) -parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(model_path, 'GFPGAN')) -parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN')) -parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN')) -parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN')) -parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(model_path, 'ScuNET')) -parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR')) -parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR')) +parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer')) +parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN')) +parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN')) +parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN')) +parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN')) +parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET')) +parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR')) +parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR')) parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") -- cgit v1.2.3 From da4ab2707b4cb0611cf181ba248a271d1937433e Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Fri, 7 Oct 2022 05:23:06 +0300 Subject: Update shared.py --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 25bb6e6c..8cc3b2fe 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -43,6 +43,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET')) parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR')) +parser.add_argument("--disable-opt-xformers-attention", action='store_true', help="force-disables xformers attention optimization") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") -- cgit v1.2.3 From bad7cb29cecac51c5c0f39afec332b007ed73133 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 7 Oct 2022 10:17:52 +0300 Subject: added support for hypernetworks (???) --- modules/hypernetwork.py | 55 ++++++++++++++++++++++++++++++++++++++ modules/sd_hijack_optimizations.py | 17 ++++++++++-- modules/shared.py | 9 ++++++- scripts/xy_grid.py | 10 +++++++ 4 files changed, 88 insertions(+), 3 deletions(-) create mode 100644 modules/hypernetwork.py (limited to 'modules/shared.py') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py new file mode 100644 index 00000000..9ed1eed9 --- /dev/null +++ b/modules/hypernetwork.py @@ -0,0 +1,55 @@ +import glob +import os +import torch +from modules import devices + + +class HypernetworkModule(torch.nn.Module): + def __init__(self, dim, state_dict): + super().__init__() + + self.linear1 = torch.nn.Linear(dim, dim * 2) + self.linear2 = torch.nn.Linear(dim * 2, dim) + + self.load_state_dict(state_dict, strict=True) + self.to(devices.device) + + def forward(self, x): + return x + (self.linear2(self.linear1(x))) + + +class Hypernetwork: + filename = None + name = None + + def __init__(self, filename): + self.filename = filename + self.name = os.path.splitext(os.path.basename(filename))[0] + self.layers = {} + + state_dict = torch.load(filename, map_location='cpu') + for size, sd in state_dict.items(): + self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) + + +def load_hypernetworks(path): + res = {} + + for filename in glob.iglob(path + '**/*.pt', recursive=True): + hn = Hypernetwork(filename) + res[hn.name] = hn + + return res + +def apply(self, x, context=None, mask=None, original=None): + + + if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork: + if context.shape[1] == 77 and CrossAttention.noise_cond: + context = context + (torch.randn_like(context) * 0.1) + h_k, h_v = CrossAttention.hypernetwork[context.shape[2]] + k = self.to_k(h_k(context)) + v = self.to_v(h_v(context)) + else: + k = self.to_k(context) + v = self.to_v(context) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index ea4cfdfc..d9cca485 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -5,6 +5,8 @@ from torch import einsum from ldm.util import default from einops import rearrange +from modules import shared + # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): @@ -42,8 +44,19 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - k_in = self.to_k(context) * self.scale - v_in = self.to_v(context) + + hypernetwork = shared.selected_hypernetwork() + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is not None: + k_in = self.to_k(hypernetwork_layers[0](context)) + v_in = self.to_v(hypernetwork_layers[1](context)) + else: + k_in = self.to_k(context) + v_in = self.to_v(context) + + k_in *= self.scale + del context, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) diff --git a/modules/shared.py b/modules/shared.py index 25bb6e6c..879d8424 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers +from modules import sd_samplers, hypernetwork from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') @@ -76,6 +76,12 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram config_filename = cmd_opts.ui_settings_file +hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks')) + + +def selected_hypernetwork(): + return hypernetworks.get(opts.sd_hypernetwork, None) + class State: interrupted = False @@ -206,6 +212,7 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), + "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 6344e612..c0c364df 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -77,6 +77,11 @@ def apply_checkpoint(p, x, xs): modules.sd_models.reload_model_weights(shared.sd_model, info) +def apply_hypernetwork(p, x, xs): + hn = shared.hypernetworks.get(x, None) + opts.data["sd_hypernetwork"] = hn.name if hn is not None else 'None' + + def format_value_add_label(p, opt, x): if type(x) == float: x = round(x, 8) @@ -122,6 +127,7 @@ axis_options = [ AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list), AxisOption("Sampler", str, apply_sampler, format_value), AxisOption("Checkpoint name", str, apply_checkpoint, format_value), + AxisOption("Hypernetwork", str, apply_hypernetwork, format_value), AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label), AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label), AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label), @@ -193,6 +199,8 @@ class Script(scripts.Script): modules.processing.fix_seed(p) p.batch_size = 1 + initial_hn = opts.sd_hypernetwork + def process_axis(opt, vals): if opt.label == 'Nothing': return [0] @@ -300,4 +308,6 @@ class Script(scripts.Script): # restore checkpoint in case it was changed by axes modules.sd_models.reload_model_weights(shared.sd_model) + opts.data["sd_hypernetwork"] = initial_hn + return processed -- cgit v1.2.3 From 706d5944a075a6523ea7f00165d630efc085ca22 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 13:38:57 +0300 Subject: let user choose his own prompt token count limit --- modules/processing.py | 6 ++++++ modules/sd_hijack.py | 13 +++++++------ modules/shared.py | 5 +++-- 3 files changed, 16 insertions(+), 8 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/processing.py b/modules/processing.py index f773a30e..d814d5ac 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -123,6 +123,7 @@ class Processed: self.index_of_first_image = index_of_first_image self.styles = p.styles self.job_timestamp = state.job_timestamp + self.max_prompt_tokens = opts.max_prompt_tokens self.eta = p.eta self.ddim_discretize = p.ddim_discretize @@ -141,6 +142,7 @@ class Processed: self.all_subseeds = all_subseeds or [self.subseed] self.infotexts = infotexts or [info] + def js(self): obj = { "prompt": self.prompt, @@ -169,6 +171,7 @@ class Processed: "infotexts": self.infotexts, "styles": self.styles, "job_timestamp": self.job_timestamp, + "max_prompt_tokens": self.max_prompt_tokens, } return json.dumps(obj) @@ -266,6 +269,8 @@ def fix_seed(p): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0): index = position_in_batch + iteration * p.batch_size + max_tokens = getattr(p, 'max_prompt_tokens', opts.max_prompt_tokens) + generation_params = { "Steps": p.steps, "Sampler": sd_samplers.samplers[p.sampler_index].name, @@ -281,6 +286,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), + "Max tokens": (None if max_tokens == shared.vanilla_max_prompt_tokens else max_tokens) } generation_params.update(p.extra_generation_params) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index d68f89cc..340329c0 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -18,7 +18,6 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward - def apply_optimizations(): undo_optimizations() @@ -83,7 +82,7 @@ class StableDiffusionModelHijack: layer.padding_mode = 'circular' if enable else 'zeros' def tokenize(self, text): - max_length = self.clip.max_length - 2 + max_length = opts.max_prompt_tokens - 2 _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, max_length @@ -94,7 +93,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.wrapped = wrapped self.hijack: StableDiffusionModelHijack = hijack self.tokenizer = wrapped.tokenizer - self.max_length = wrapped.max_length self.token_mults = {} tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] @@ -116,7 +114,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def tokenize_line(self, line, used_custom_terms, hijack_comments): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id - maxlen = self.wrapped.max_length + maxlen = opts.max_prompt_tokens if opts.enable_emphasis: parsed = prompt_parser.parse_prompt_attention(line) @@ -191,7 +189,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def process_text_old(self, text): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id - maxlen = self.wrapped.max_length + maxlen = self.wrapped.max_length # you get to stay at 77 used_custom_terms = [] remade_batch_tokens = [] overflowing_words = [] @@ -268,8 +266,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if len(used_custom_terms) > 0: self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + position_ids_array = [min(x, 75) for x in range(len(remade_batch_tokens[0])-1)] + [76] + position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1)) + tokens = torch.asarray(remade_batch_tokens).to(device) - outputs = self.wrapped.transformer(input_ids=tokens) + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) z = outputs.last_hidden_state # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise diff --git a/modules/shared.py b/modules/shared.py index 879d8424..864e772c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -118,8 +118,8 @@ prompt_styles = modules.styles.StyleDatabase(styles_filename) interrogator = modules.interrogate.InterrogateModels("interrogate") face_restorers = [] -# This was moved to webui.py with the other model "setup" calls. -# modules.sd_models.list_models() + +vanilla_max_prompt_tokens = 77 def realesrgan_models_names(): @@ -221,6 +221,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), + "max_prompt_tokens": OptionInfo(vanilla_max_prompt_tokens, f"Max prompt token count. Two tokens are reserved for for start and end. Default is {vanilla_max_prompt_tokens}. Setting this to a different value will result in different pictures for same seed.", gr.Number, {"precision": 0}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) -- cgit v1.2.3 From 786d9f63aaa4515df82eb2cf357ea92f3dae1e29 Mon Sep 17 00:00:00 2001 From: Trung Ngo Date: Tue, 4 Oct 2022 22:56:30 -0500 Subject: Add button to skip the current iteration --- javascript/hints.js | 1 + javascript/progressbar.js | 20 ++++++++++++++------ modules/img2img.py | 4 ++++ modules/processing.py | 4 ++++ modules/shared.py | 5 +++++ modules/ui.py | 8 ++++++++ style.css | 14 ++++++++++++-- webui.py | 1 + 8 files changed, 49 insertions(+), 8 deletions(-) (limited to 'modules/shared.py') diff --git a/javascript/hints.js b/javascript/hints.js index 8adcd983..8e352e94 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -35,6 +35,7 @@ titles = { "Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.", "Denoising strength change factor": "In loopback mode, on each loop the denoising strength is multiplied by this value. <1 means decreasing variety so your sequence will converge on a fixed picture. >1 means increasing variety so your sequence will become more and more chaotic.", + "Skip": "Stop processing current image and continue processing.", "Interrupt": "Stop processing images and return any results accumulated so far.", "Save": "Write image to a directory (default - log/images) and generation parameters into csv file.", diff --git a/javascript/progressbar.js b/javascript/progressbar.js index f9e9290e..4395a215 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -1,8 +1,9 @@ // code related to showing and updating progressbar shown as the image is being made global_progressbars = {} -function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_interrupt, id_preview, id_gallery){ +function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){ var progressbar = gradioApp().getElementById(id_progressbar) + var skip = id_skip ? gradioApp().getElementById(id_skip) : null var interrupt = gradioApp().getElementById(id_interrupt) if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){ @@ -32,30 +33,37 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; if(!progressDiv){ + if (skip) { + skip.style.display = "none" + } interrupt.style.display = "none" } } - window.setTimeout(function(){ requestMoreProgress(id_part, id_progressbar_span, id_interrupt) }, 500) + window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500) }); mutationObserver.observe( progressbar, { childList:true, subtree:true }) } } onUiUpdate(function(){ - check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') - check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') - check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery') + check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') + check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') + check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery') }) -function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){ +function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){ btn = gradioApp().getElementById(id_part+"_check_progress"); if(btn==null) return; btn.click(); var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; + var skip = id_skip ? gradioApp().getElementById(id_skip) : null var interrupt = gradioApp().getElementById(id_interrupt) if(progressDiv && interrupt){ + if (skip) { + skip.style.display = "block" + } interrupt.style.display = "block" } } diff --git a/modules/img2img.py b/modules/img2img.py index da212d72..e60b7e0f 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -32,6 +32,10 @@ def process_batch(p, input_dir, output_dir, args): for i, image in enumerate(images): state.job = f"{i+1} out of {len(images)}" + if state.skipped: + state.skipped = False + state.interrupted = False + continue if state.interrupted: break diff --git a/modules/processing.py b/modules/processing.py index d814d5ac..6805039c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -355,6 +355,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: state.job_count = p.n_iter for n in range(p.n_iter): + if state.skipped: + state.skipped = False + state.interrupted = False + if state.interrupted: break diff --git a/modules/shared.py b/modules/shared.py index 864e772c..7f802bd9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -84,6 +84,7 @@ def selected_hypernetwork(): class State: + skipped = False interrupted = False job = "" job_no = 0 @@ -96,6 +97,10 @@ class State: current_image_sampling_step = 0 textinfo = None + def skip(self): + self.skipped = True + self.interrupted = True + def interrupt(self): self.interrupted = True diff --git a/modules/ui.py b/modules/ui.py index 4f18126f..e3e62fdd 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -191,6 +191,7 @@ def wrap_gradio_call(func, extra_outputs=None): # last item is always HTML res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" + shared.state.skipped = False shared.state.interrupted = False shared.state.job_count = 0 @@ -411,9 +412,16 @@ def create_toprow(is_img2img): with gr.Column(scale=1): with gr.Row(): + skip = gr.Button('Skip', elem_id=f"{id_part}_skip") interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') + skip.click( + fn=lambda: shared.state.skip(), + inputs=[], + outputs=[], + ) + interrupt.click( fn=lambda: shared.state.interrupt(), inputs=[], diff --git a/style.css b/style.css index 50c5e557..6904fc50 100644 --- a/style.css +++ b/style.css @@ -393,10 +393,20 @@ input[type="range"]{ #txt2img_interrupt, #img2img_interrupt{ position: absolute; - width: 100%; + width: 50%; height: 72px; background: #b4c0cc; - border-radius: 8px; + border-radius: 0px; + display: none; +} + +#txt2img_skip, #img2img_skip{ + position: absolute; + width: 50%; + right: 0px; + height: 72px; + background: #b4c0cc; + border-radius: 0px; display: none; } diff --git a/webui.py b/webui.py index 480360fe..3b4cf5e9 100644 --- a/webui.py +++ b/webui.py @@ -58,6 +58,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): shared.state.current_latent = None shared.state.current_image = None shared.state.current_image_sampling_step = 0 + shared.state.skipped = False shared.state.interrupted = False shared.state.textinfo = None -- cgit v1.2.3 From 00117a07efbbe8482add12262a179326541467de Mon Sep 17 00:00:00 2001 From: Trung Ngo Date: Sat, 8 Oct 2022 05:33:21 -0500 Subject: check specifically for skipped --- modules/img2img.py | 2 -- modules/processing.py | 3 +-- modules/sd_samplers.py | 4 ++-- modules/shared.py | 1 - 4 files changed, 3 insertions(+), 7 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/img2img.py b/modules/img2img.py index e60b7e0f..24126774 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -34,8 +34,6 @@ def process_batch(p, input_dir, output_dir, args): state.job = f"{i+1} out of {len(images)}" if state.skipped: state.skipped = False - state.interrupted = False - continue if state.interrupted: break diff --git a/modules/processing.py b/modules/processing.py index 6805039c..3657fe69 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -357,7 +357,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed: for n in range(p.n_iter): if state.skipped: state.skipped = False - state.interrupted = False if state.interrupted: break @@ -385,7 +384,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: with devices.autocast(): samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) - if state.interrupted: + if state.interrupted or state.skipped: # if we are interruped, sample returns just noise # use the image collected previously in sampler loop diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index df17e93c..13a8b322 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -106,7 +106,7 @@ def extended_tdqm(sequence, *args, desc=None, **kwargs): seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs) for x in seq: - if state.interrupted: + if state.interrupted or state.skipped: break yield x @@ -254,7 +254,7 @@ def extended_trange(sampler, count, *args, **kwargs): seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs) for x in seq: - if state.interrupted: + if state.interrupted or state.skipped: break if sampler.stop_at is not None and x > sampler.stop_at: diff --git a/modules/shared.py b/modules/shared.py index 7f802bd9..ca462628 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -99,7 +99,6 @@ class State: def skip(self): self.skipped = True - self.interrupted = True def interrupt(self): self.interrupted = True -- cgit v1.2.3 From 4999eb2ef9b30e8c42ca7e4a94d4bbffe4d1f015 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 14:25:47 +0300 Subject: do not let user choose his own prompt token count limit --- README.md | 1 + modules/processing.py | 5 ----- modules/sd_hijack.py | 25 ++++++++++++------------- modules/shared.py | 3 --- 4 files changed, 13 insertions(+), 21 deletions(-) (limited to 'modules/shared.py') diff --git a/README.md b/README.md index d6e1d50b..ef9b5e31 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once - separate prompts using uppercase `AND` - also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2` +- No token limit for prompts (original stable diffusion lets you use up to 75 tokens) ## Installation and Running Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. diff --git a/modules/processing.py b/modules/processing.py index 3657fe69..d5162ddc 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -123,7 +123,6 @@ class Processed: self.index_of_first_image = index_of_first_image self.styles = p.styles self.job_timestamp = state.job_timestamp - self.max_prompt_tokens = opts.max_prompt_tokens self.eta = p.eta self.ddim_discretize = p.ddim_discretize @@ -171,7 +170,6 @@ class Processed: "infotexts": self.infotexts, "styles": self.styles, "job_timestamp": self.job_timestamp, - "max_prompt_tokens": self.max_prompt_tokens, } return json.dumps(obj) @@ -269,8 +267,6 @@ def fix_seed(p): def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0): index = position_in_batch + iteration * p.batch_size - max_tokens = getattr(p, 'max_prompt_tokens', opts.max_prompt_tokens) - generation_params = { "Steps": p.steps, "Sampler": sd_samplers.samplers[p.sampler_index].name, @@ -286,7 +282,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), - "Max tokens": (None if max_tokens == shared.vanilla_max_prompt_tokens else max_tokens) } generation_params.update(p.extra_generation_params) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 340329c0..2c1332c9 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -36,6 +36,13 @@ def undo_optimizations(): ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward +def get_target_prompt_token_count(token_count): + if token_count < 75: + return 75 + + return math.ceil(token_count / 10) * 10 + + class StableDiffusionModelHijack: fixes = None comments = [] @@ -84,7 +91,7 @@ class StableDiffusionModelHijack: def tokenize(self, text): max_length = opts.max_prompt_tokens - 2 _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) - return remade_batch_tokens[0], token_count, max_length + return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): @@ -114,7 +121,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def tokenize_line(self, line, used_custom_terms, hijack_comments): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id - maxlen = opts.max_prompt_tokens if opts.enable_emphasis: parsed = prompt_parser.parse_prompt_attention(line) @@ -146,19 +152,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): used_custom_terms.append((embedding.name, embedding.checksum())) i += embedding_length_in_tokens - if len(remade_tokens) > maxlen - 2: - vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()} - ovf = remade_tokens[maxlen - 2:] - overflowing_words = [vocab.get(int(x), "") for x in ovf] - overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") - token_count = len(remade_tokens) - remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) - remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] + prompt_target_length = get_target_prompt_token_count(token_count) + tokens_to_add = prompt_target_length - len(remade_tokens) + 1 - multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) - multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0] + remade_tokens = [id_start] + remade_tokens + [id_end] * tokens_to_add + multipliers = [1.0] + multipliers + [1.0] * tokens_to_add return remade_tokens, fixes, multipliers, token_count diff --git a/modules/shared.py b/modules/shared.py index ca462628..475d7e52 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -123,8 +123,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate") face_restorers = [] -vanilla_max_prompt_tokens = 77 - def realesrgan_models_names(): import modules.realesrgan_model @@ -225,7 +223,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), - "max_prompt_tokens": OptionInfo(vanilla_max_prompt_tokens, f"Max prompt token count. Two tokens are reserved for for start and end. Default is {vanilla_max_prompt_tokens}. Setting this to a different value will result in different pictures for same seed.", gr.Number, {"precision": 0}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) -- cgit v1.2.3 From ddfa9a97865c732193023a71521c5b7b53d8571b Mon Sep 17 00:00:00 2001 From: C43H66N12O12S2 <36072735+C43H66N12O12S2@users.noreply.github.com> Date: Sat, 8 Oct 2022 16:20:41 +0300 Subject: add xformers_available shared variable --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 8cc3b2fe..6ed4b802 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -74,7 +74,7 @@ device = devices.device batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram - +xformers_available = False config_filename = cmd_opts.ui_settings_file -- cgit v1.2.3 From dc1117233ef8f9b25ff1ac40b158f20b70ba2fcb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 17:02:18 +0300 Subject: simplify xfrmers options: --xformers to enable and that's it --- launch.py | 2 +- modules/sd_hijack.py | 2 +- modules/sd_hijack_optimizations.py | 20 +++++++++++++------- modules/shared.py | 2 +- 4 files changed, 16 insertions(+), 10 deletions(-) (limited to 'modules/shared.py') diff --git a/launch.py b/launch.py index a592e1ba..61f62096 100644 --- a/launch.py +++ b/launch.py @@ -32,7 +32,7 @@ def extract_arg(args, name): args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test') -args, xformers = extract_arg(args, '--xformers') +xformers = '--xformers' in args def repo_dir(name): diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 5d93f7f6..91e98c16 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -22,7 +22,7 @@ def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - if not cmd_opts.disable_opt_xformers_attention and not (cmd_opts.opt_split_attention or torch.version.hip) and shared.xformers_available: + if cmd_opts.xformers and shared.xformers_available and not torch.version.hip: ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 05023b6f..d23d733b 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,4 +1,7 @@ import math +import sys +import traceback + import torch from torch import einsum @@ -7,13 +10,16 @@ from einops import rearrange from modules import shared -try: - import xformers.ops - import functorch - xformers._is_functorch_available = True - shared.xformers_available = True -except Exception: - print('Cannot find xformers, defaulting to split attention. Try adding --xformers commandline argument to your webui-user file if you wish to install it.') +if shared.cmd_opts.xformers: + try: + import xformers.ops + import functorch + xformers._is_functorch_available = True + shared.xformers_available = True + except Exception: + print("Cannot import xformers", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion def split_cross_attention_forward_v1(self, x, context=None, mask=None): diff --git a/modules/shared.py b/modules/shared.py index d68df751..02cb2722 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -43,7 +43,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET')) parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR')) -parser.add_argument("--disable-opt-xformers-attention", action='store_true', help="force-disables xformers attention optimization") +parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") -- cgit v1.2.3 From 01f8cb44474e454903c11718e6a4f33dbde34bb8 Mon Sep 17 00:00:00 2001 From: Greendayle Date: Sat, 8 Oct 2022 18:02:56 +0200 Subject: made deepdanbooru optional, added to readme, automatic download of deepbooru model --- README.md | 2 ++ launch.py | 4 ++++ modules/deepbooru.py | 20 ++++++++++---------- modules/shared.py | 1 + modules/ui.py | 19 ++++++++++++------- requirements.txt | 3 --- requirements_versions.txt | 3 --- 7 files changed, 29 insertions(+), 23 deletions(-) (limited to 'modules/shared.py') diff --git a/README.md b/README.md index ef9b5e31..6cd7a1f9 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - separate prompts using uppercase `AND` - also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2` - No token limit for prompts (original stable diffusion lets you use up to 75 tokens) +- DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args) ## Installation and Running Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. @@ -123,4 +124,5 @@ The documentation was moved from this README over to the project's [wiki](https: - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot - CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. +- DeepDanbooru - interrogator for anime diffusors https://github.com/KichangKim/DeepDanbooru - (You) diff --git a/launch.py b/launch.py index 61f62096..d46426eb 100644 --- a/launch.py +++ b/launch.py @@ -33,6 +33,7 @@ def extract_arg(args, name): args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test') xformers = '--xformers' in args +deepdanbooru = '--deepdanbooru' in args def repo_dir(name): @@ -132,6 +133,9 @@ if not is_installed("xformers") and xformers: elif platform.system() == "Linux": run_pip("install xformers", "xformers") +if not is_installed("deepdanbooru") and deepdanbooru: + run_pip("install git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26#egg=deepdanbooru[tensorflow] tensorflow==2.10.0 tensorflow-io==0.27.0", "deepdanbooru") + os.makedirs(dir_repos, exist_ok=True) git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash) diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 781b2249..7e3c0618 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -9,16 +9,16 @@ def _load_tf_and_return_tags(pil_image, threshold): import numpy as np this_folder = os.path.dirname(__file__) - model_path = os.path.join(this_folder, '..', 'models', 'deepbooru', 'deepdanbooru-v3-20211112-sgd-e28') - - model_good = False - for path_candidate in [model_path, os.path.dirname(model_path)]: - if os.path.exists(os.path.join(path_candidate, 'project.json')): - model_path = path_candidate - model_good = True - if not model_good: - return ("Download https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/" - "deepdanbooru-v3-20211112-sgd-e28.zip unpack and put into models/deepbooru") + model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru')) + if not os.path.exists(os.path.join(model_path, 'project.json')): + # there is no point importing these every time + import zipfile + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip", + model_path) + with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref: + zip_ref.extractall(model_path) + os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip")) tags = dd.project.load_tags_from_project(model_path) model = dd.project.load_model_from_project( diff --git a/modules/shared.py b/modules/shared.py index 02cb2722..c87b726e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -44,6 +44,7 @@ parser.add_argument("--scunet-models-path", type=str, help="Path to directory wi parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR')) parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") +parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") diff --git a/modules/ui.py b/modules/ui.py index 30583fe9..c5c11c3c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -23,9 +23,10 @@ import gradio.utils import gradio.routes from modules import sd_hijack -from modules.deepbooru import get_deepbooru_tags from modules.paths import script_path from modules.shared import opts, cmd_opts +if cmd_opts.deepdanbooru: + from modules.deepbooru import get_deepbooru_tags import modules.shared as shared from modules.sd_samplers import samplers, samplers_for_img2img from modules.sd_hijack import model_hijack @@ -437,7 +438,10 @@ def create_toprow(is_img2img): with gr.Row(scale=1): if is_img2img: interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + if cmd_opts.deepdanbooru: + deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + else: + deepbooru = None else: interrogate = None deepbooru = None @@ -782,11 +786,12 @@ def create_ui(wrap_gradio_gpu_call): outputs=[img2img_prompt], ) - img2img_deepbooru.click( - fn=interrogate_deepbooru, - inputs=[init_img], - outputs=[img2img_prompt], - ) + if cmd_opts.deepdanbooru: + img2img_deepbooru.click( + fn=interrogate_deepbooru, + inputs=[init_img], + outputs=[img2img_prompt], + ) save.click( fn=wrap_gradio_call(save_files), diff --git a/requirements.txt b/requirements.txt index cd3953c6..81641d68 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,4 @@ resize-right torchdiffeq kornia lark -deepdanbooru -tensorflow -tensorflow-io functorch diff --git a/requirements_versions.txt b/requirements_versions.txt index 2d256a54..fec3e9d5 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -22,7 +22,4 @@ resize-right==0.0.2 torchdiffeq==0.2.3 kornia==0.6.7 lark==1.1.2 -git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26#egg=deepdanbooru[tensorflow] -tensorflow==2.10.0 -tensorflow-io==0.27.0 functorch==0.2.1 -- cgit v1.2.3 From 3061cdb7b610d4ba7f1ea695d9d6364b591e5bc7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 8 Oct 2022 19:22:15 +0300 Subject: add --force-enable-xformers option and also add messages to console regarding cross attention optimizations --- modules/sd_hijack.py | 6 +++++- modules/shared.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index a3e374f0..307cc67d 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -22,12 +22,16 @@ def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.xformers and shared.xformers_available and torch.version.cuda and torch.cuda.get_device_capability(shared.device) == (8, 6): + + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and torch.cuda.get_device_capability(shared.device) == (8, 6)): + print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: + print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): + print("Applying cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward diff --git a/modules/shared.py b/modules/shared.py index 02cb2722..8f941226 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -44,6 +44,7 @@ parser.add_argument("--scunet-models-path", type=str, help="Path to directory wi parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR')) parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") +parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") -- cgit v1.2.3 From 1371d7608b402d6f15c200ec2f5fde4579836a05 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sat, 8 Oct 2022 14:28:22 -0400 Subject: Added ability to ignore last n layers in FrozenCLIPEmbedder --- modules/sd_hijack.py | 11 +++++++++-- modules/shared.py | 1 + 2 files changed, 10 insertions(+), 2 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 307cc67d..f12a9696 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -281,8 +281,15 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens] tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device) - outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) - z = outputs.last_hidden_state + + tmp = -opts.CLIP_ignore_last_layers + if (opts.CLIP_ignore_last_layers == 0): + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids) + z = outputs.last_hidden_state + else: + outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=tmp) + z = outputs.hidden_states[tmp] + z = self.wrapped.transformer.text_model.final_layer_norm(z) # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers] diff --git a/modules/shared.py b/modules/shared.py index 8f941226..af8dc744 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -225,6 +225,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), + 'CLIP_ignore_last_layers': OptionInfo(0, "Ignore last layers of CLIP model", gr.Slider, {"minimum": 0, "maximum": 5, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) -- cgit v1.2.3 From 432782163ae53e605470bcefc9a6f796c4556912 Mon Sep 17 00:00:00 2001 From: Aidan Holland Date: Sat, 8 Oct 2022 15:12:24 -0400 Subject: chore: Fix typos --- README.md | 2 +- javascript/imageviewer.js | 2 +- modules/interrogate.py | 4 ++-- modules/processing.py | 2 +- modules/scunet_model_arch.py | 4 ++-- modules/sd_models.py | 4 ++-- modules/sd_samplers.py | 4 ++-- modules/shared.py | 6 +++--- modules/swinir_model_arch.py | 2 +- modules/ui.py | 4 ++-- 10 files changed, 17 insertions(+), 17 deletions(-) (limited to 'modules/shared.py') diff --git a/README.md b/README.md index ef9b5e31..63dd0c18 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - Sampling method selection - Interrupt processing at any time - 4GB video card support (also reports of 2GB working) -- Correct seeds for batches +- Correct seeds for batches - Prompt length validation - get length of prompt in tokens as you type - get a warning after generation if some text was truncated diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 4c0e8f4b..6a00c0da 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -95,7 +95,7 @@ function showGalleryImage(){ e.addEventListener('click', function (evt) { if(!opts.js_modal_lightbox) return; - modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initialy_zoomed) + modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed) showModal(evt) },true); } diff --git a/modules/interrogate.py b/modules/interrogate.py index eed87144..635e266e 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -140,11 +140,11 @@ class InterrogateModels: res = caption - cilp_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device) + clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device) precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext with torch.no_grad(), precision_scope("cuda"): - image_features = self.clip_model.encode_image(cilp_image).type(self.dtype) + image_features = self.clip_model.encode_image(clip_image).type(self.dtype) image_features /= image_features.norm(dim=-1, keepdim=True) diff --git a/modules/processing.py b/modules/processing.py index 515fc91a..31220881 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -386,7 +386,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if state.interrupted or state.skipped: - # if we are interruped, sample returns just noise + # if we are interrupted, sample returns just noise # use the image collected previously in sampler loop samples_ddim = shared.state.current_latent diff --git a/modules/scunet_model_arch.py b/modules/scunet_model_arch.py index 972a2639..43ca8d36 100644 --- a/modules/scunet_model_arch.py +++ b/modules/scunet_model_arch.py @@ -40,7 +40,7 @@ class WMSA(nn.Module): Returns: attn_mask: should be (1 1 w p p), """ - # supporting sqaure. + # supporting square. attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device) if self.type == 'W': return attn_mask @@ -65,7 +65,7 @@ class WMSA(nn.Module): x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size) h_windows = x.size(1) w_windows = x.size(2) - # sqaure validation + # square validation # assert h_windows == w_windows x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size) diff --git a/modules/sd_models.py b/modules/sd_models.py index 9409d070..a09866ce 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -147,7 +147,7 @@ def load_model_weights(model, checkpoint_file, sd_model_hash): model.first_stage_model.load_state_dict(vae_dict) model.sd_model_hash = sd_model_hash - model.sd_model_checkpint = checkpoint_file + model.sd_model_checkpoint = checkpoint_file def load_model(): @@ -175,7 +175,7 @@ def reload_model_weights(sd_model, info=None): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() - if sd_model.sd_model_checkpint == checkpoint_info.filename: + if sd_model.sd_model_checkpoint == checkpoint_info.filename: return if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index eade0dbb..6e743f7e 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -181,7 +181,7 @@ class VanillaStableDiffusionSampler: self.initialize(p) - # existing code fails with cetain step counts, like 9 + # existing code fails with certain step counts, like 9 try: self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False) except Exception: @@ -204,7 +204,7 @@ class VanillaStableDiffusionSampler: steps = steps or p.steps - # existing code fails with cetin step counts, like 9 + # existing code fails with certain step counts, like 9 try: samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta) except Exception: diff --git a/modules/shared.py b/modules/shared.py index af8dc744..2dc092d6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -141,9 +141,9 @@ class OptionInfo: self.section = None -def options_section(section_identifer, options_dict): +def options_section(section_identifier, options_dict): for k, v in options_dict.items(): - v.section = section_identifer + v.section = section_identifier return options_dict @@ -246,7 +246,7 @@ options_templates.update(options_section(('ui', "User interface"), { "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), - "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), + "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), })) diff --git a/modules/swinir_model_arch.py b/modules/swinir_model_arch.py index 461fb354..863f42db 100644 --- a/modules/swinir_model_arch.py +++ b/modules/swinir_model_arch.py @@ -166,7 +166,7 @@ class SwinTransformerBlock(nn.Module): Args: dim (int): Number of input channels. - input_resolution (tuple[int]): Input resulotion. + input_resolution (tuple[int]): Input resolution. num_heads (int): Number of attention heads. window_size (int): Window size. shift_size (int): Shift size for SW-MSA. diff --git a/modules/ui.py b/modules/ui.py index b09359aa..b51af121 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -38,7 +38,7 @@ from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui -# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI +# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() mimetypes.add_type('application/javascript', '.js') @@ -102,7 +102,7 @@ def save_files(js_data, images, index): import csv filenames = [] - #quick dictionary to class object conversion. Its neccesary due apply_filename_pattern requiring it + #quick dictionary to class object conversion. Its necessary due apply_filename_pattern requiring it class MyObject: def __init__(self, d=None): if d is not None: -- cgit v1.2.3 From 122d42687b97ec4df4c2a8c335d2de385cd1f1a1 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sat, 8 Oct 2022 22:37:35 -0400 Subject: Fix VRAM Issue by only loading in hypernetwork when selected in settings --- modules/hypernetwork.py | 23 +++++++++++++++-------- modules/sd_hijack_optimizations.py | 6 +++--- modules/shared.py | 7 ++----- webui.py | 3 +++ 4 files changed, 23 insertions(+), 16 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index 7f062242..19f1c227 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -40,18 +40,25 @@ class Hypernetwork: self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) -def load_hypernetworks(path): +def list_hypernetworks(path): res = {} - for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): + name = os.path.splitext(os.path.basename(filename))[0] + res[name] = filename + return res + + +def load_hypernetwork(filename): + print(f"Loading hypernetwork {filename}") + path = shared.hypernetworks.get(filename, None) + if (path is not None): try: - hn = Hypernetwork(filename) - res[hn.name] = hn + shared.loaded_hypernetwork = Hypernetwork(path) except Exception: - print(f"Error loading hypernetwork {filename}", file=sys.stderr) + print(f"Error loading hypernetwork {path}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - - return res + else: + shared.loaded_hypernetwork = None def attention_CrossAttention_forward(self, x, context=None, mask=None): @@ -60,7 +67,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - hypernetwork = shared.selected_hypernetwork() + hypernetwork = shared.loaded_hypernetwork hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) if hypernetwork_layers is not None: diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index c4396bb9..634fb4b2 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -28,7 +28,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.selected_hypernetwork() + hypernetwork = shared.loaded_hypernetwork hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) if hypernetwork_layers is not None: @@ -68,7 +68,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.selected_hypernetwork() + hypernetwork = shared.loaded_hypernetwork hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) if hypernetwork_layers is not None: @@ -132,7 +132,7 @@ def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) - hypernetwork = shared.selected_hypernetwork() + hypernetwork = shared.loaded_hypernetwork hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) if hypernetwork_layers is not None: k_in = self.to_k(hypernetwork_layers[0](context)) diff --git a/modules/shared.py b/modules/shared.py index b2c76a32..9dce6cb7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -79,11 +79,8 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram xformers_available = False config_filename = cmd_opts.ui_settings_file -hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks')) - - -def selected_hypernetwork(): - return hypernetworks.get(opts.sd_hypernetwork, None) +hypernetworks = hypernetwork.list_hypernetworks(os.path.join(models_path, 'hypernetworks')) +loaded_hypernetwork = None class State: diff --git a/webui.py b/webui.py index 18de8e16..270584f7 100644 --- a/webui.py +++ b/webui.py @@ -82,6 +82,9 @@ modules.scripts.load_scripts(os.path.join(script_path, "scripts")) shared.sd_model = modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) +loaded_hypernetwork = modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork) +shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) + def webui(): # make the program just exit at ctrl+c without waiting for anything -- cgit v1.2.3 From 1ffeb42d38d9276dc28918189d32f60d593a162c Mon Sep 17 00:00:00 2001 From: Nicolas Noullet Date: Sun, 9 Oct 2022 00:18:45 +0200 Subject: Fix typo --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 9dce6cb7..dffa0094 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -238,7 +238,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), - "show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), + "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), -- cgit v1.2.3 From d6d10a37bfd21568e74efb46137f906da96d5fdb Mon Sep 17 00:00:00 2001 From: William Moorehouse Date: Sun, 9 Oct 2022 04:58:40 -0400 Subject: Added extended model details to infotext --- modules/processing.py | 3 +++ modules/sd_models.py | 3 ++- modules/shared.py | 1 + 3 files changed, 6 insertions(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/processing.py b/modules/processing.py index 2c991317..d1bcee4a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -284,6 +284,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), + "Model": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_model_name else shared.sd_model.sd_model_name), + "Model VAE": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_model_vae_name else shared.sd_model.sd_model_vae_name), + "Model hypernetwork": (None if not opts.add_extended_model_details_to_info or not opts.sd_hypernetwork else opts.sd_hypernetwork), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), diff --git a/modules/sd_models.py b/modules/sd_models.py index d0c74dd8..3fa42329 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -4,7 +4,7 @@ import sys from collections import namedtuple import torch from omegaconf import OmegaConf - +from pathlib import Path from ldm.util import instantiate_from_config @@ -158,6 +158,7 @@ def load_model_weights(model, checkpoint_info): vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} model.first_stage_model.load_state_dict(vae_dict) + model.sd_model_vae_name = Path(vae_file).stem model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file diff --git a/modules/shared.py b/modules/shared.py index dffa0094..ca63f7d8 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -242,6 +242,7 @@ options_templates.update(options_section(('ui', "User interface"), { "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), + "add_extended_model_details_to_info": OptionInfo(False, "Add extended model details to generation information (model name, VAE, hypernetwork)"), "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), -- cgit v1.2.3 From e6e8cabe0c9c335e0d72345602c069b198558b53 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 14:57:48 +0300 Subject: change up #2056 to make it work how i want it to plus make xy plot write correct values to images --- modules/processing.py | 5 ++--- modules/sd_models.py | 2 -- modules/shared.py | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/processing.py b/modules/processing.py index 049f3769..04aed989 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -284,9 +284,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), - "Model": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Model VAE": (None if not opts.add_extended_model_details_to_info or not shared.sd_model.sd_model_vae_name else shared.sd_model.sd_model_vae_name.replace(',', '').replace(':', '')), - "Model hypernetwork": (None if not opts.add_extended_model_details_to_info or not opts.sd_hypernetwork else opts.sd_hypernetwork.replace(',', '').replace(':', '')), + "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), + "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), diff --git a/modules/sd_models.py b/modules/sd_models.py index 3fa42329..e63d3c29 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -4,7 +4,6 @@ import sys from collections import namedtuple import torch from omegaconf import OmegaConf -from pathlib import Path from ldm.util import instantiate_from_config @@ -158,7 +157,6 @@ def load_model_weights(model, checkpoint_info): vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} model.first_stage_model.load_state_dict(vae_dict) - model.sd_model_vae_name = Path(vae_file).stem model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file diff --git a/modules/shared.py b/modules/shared.py index ca63f7d8..6ecc2503 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -242,7 +242,7 @@ options_templates.update(options_section(('ui', "User interface"), { "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), - "add_extended_model_details_to_info": OptionInfo(False, "Add extended model details to generation information (model name, VAE, hypernetwork)"), + "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"), "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), -- cgit v1.2.3 From 875ddfeecfaffad9eee24813301637cba310337d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 17:58:43 +0300 Subject: added guard for torch.load to prevent loading pickles with unknown content --- modules/paths.py | 1 + modules/safe.py | 89 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ modules/shared.py | 1 + 3 files changed, 91 insertions(+) create mode 100644 modules/safe.py (limited to 'modules/shared.py') diff --git a/modules/paths.py b/modules/paths.py index 0519caa0..1e7a2fbc 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -1,6 +1,7 @@ import argparse import os import sys +import modules.safe script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) models_path = os.path.join(script_path, "models") diff --git a/modules/safe.py b/modules/safe.py new file mode 100644 index 00000000..2d2c1371 --- /dev/null +++ b/modules/safe.py @@ -0,0 +1,89 @@ +# this code is adapted from the script contributed by anon from /h/ + +import io +import pickle +import collections +import sys +import traceback + +import torch +import numpy +import _codecs +import zipfile + + +def encode(*args): + out = _codecs.encode(*args) + return out + + +class RestrictedUnpickler(pickle.Unpickler): + def persistent_load(self, saved_id): + assert saved_id[0] == 'storage' + return torch.storage._TypedStorage() + + def find_class(self, module, name): + if module == 'collections' and name == 'OrderedDict': + return getattr(collections, name) + if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']: + return getattr(torch._utils, name) + if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage']: + return getattr(torch, name) + if module == 'torch.nn.modules.container' and name in ['ParameterDict']: + return getattr(torch.nn.modules.container, name) + if module == 'numpy.core.multiarray' and name == 'scalar': + return numpy.core.multiarray.scalar + if module == 'numpy' and name == 'dtype': + return numpy.dtype + if module == '_codecs' and name == 'encode': + return encode + if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint': + import pytorch_lightning.callbacks + return pytorch_lightning.callbacks.model_checkpoint + if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': + import pytorch_lightning.callbacks.model_checkpoint + return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint + if module == "__builtin__" and name == 'set': + return set + + # Forbid everything else. + raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden") + + +def check_pt(filename): + try: + + # new pytorch format is a zip file + with zipfile.ZipFile(filename) as z: + with z.open('archive/data.pkl') as file: + unpickler = RestrictedUnpickler(file) + unpickler.load() + + except zipfile.BadZipfile: + + # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle + with open(filename, "rb") as file: + unpickler = RestrictedUnpickler(file) + for i in range(5): + unpickler.load() + + +def load(filename, *args, **kwargs): + from modules import shared + + try: + if not shared.cmd_opts.disable_safe_unpickle: + check_pt(filename) + + except Exception: + print(f"Error verifying pickled file from {filename}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) + print(f"You can skip this check with --disable-safe-unpickle commandline argument.", file=sys.stderr) + return None + + return unsafe_torch_load(filename, *args, **kwargs) + + +unsafe_torch_load = torch.load +torch.load = load diff --git a/modules/shared.py b/modules/shared.py index 6ecc2503..3d7f08e1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -65,6 +65,7 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) +parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) cmd_opts = parser.parse_args() -- cgit v1.2.3 From 6c383d2e82045fc4475d665f83bdeeac8fd844d9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 22:24:07 +0300 Subject: show model selection setting on top of page --- modules/shared.py | 5 +++-- modules/ui.py | 54 +++++++++++++++++++++++++++++++++++++++++++++--------- style.css | 9 +++++++++ 3 files changed, 57 insertions(+), 11 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 3d7f08e1..270fa402 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -131,13 +131,14 @@ def realesrgan_models_names(): class OptionInfo: - def __init__(self, default=None, label="", component=None, component_args=None, onchange=None): + def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False): self.default = default self.label = label self.component = component self.component_args = component_args self.onchange = onchange self.section = None + self.show_on_main_page = show_on_main_page def options_section(section_identifier, options_dict): @@ -214,7 +215,7 @@ options_templates.update(options_section(('system', "System"), { })) options_templates.update(options_section(('sd', "Stable Diffusion"), { - "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), + "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True), "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), diff --git a/modules/ui.py b/modules/ui.py index dad509f3..2231a8ed 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1175,10 +1175,13 @@ Requested path was: {f} changed = 0 for key, value, comp in zip(opts.data_labels.keys(), args, components): - if not opts.same_type(value, opts.data_labels[key].default): - return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" + if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default): + return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson() for key, value, comp in zip(opts.data_labels.keys(), args, components): + if comp == dummy_component: + continue + comp_args = opts.data_labels[key].component_args if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: continue @@ -1196,6 +1199,21 @@ Requested path was: {f} return f'{changed} settings changed.', opts.dumpjson() + def run_settings_single(value, key): + if not opts.same_type(value, opts.data_labels[key].default): + return gr.update(visible=True), opts.dumpjson() + + oldval = opts.data.get(key, None) + opts.data[key] = value + + if oldval != value: + if opts.data_labels[key].onchange is not None: + opts.data_labels[key].onchange() + + opts.save(shared.config_filename) + + return gr.update(value=value), opts.dumpjson() + with gr.Blocks(analytics_enabled=False) as settings_interface: settings_submit = gr.Button(value="Apply settings", variant='primary') result = gr.HTML() @@ -1203,6 +1221,8 @@ Requested path was: {f} settings_cols = 3 items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols) + quicksettings_list = [] + cols_displayed = 0 items_displayed = 0 previous_section = None @@ -1225,10 +1245,14 @@ Requested path was: {f} gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='

{}

'.format(item.section[1])) - component = create_setting_component(k) - component_dict[k] = component - components.append(component) - items_displayed += 1 + if item.show_on_main_page: + quicksettings_list.append((i, k, item)) + components.append(dummy_component) + else: + component = create_setting_component(k) + component_dict[k] = component + components.append(component) + items_displayed += 1 request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") request_notifications.click( @@ -1242,7 +1266,6 @@ Requested path was: {f} reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') - def reload_scripts(): modules.scripts.reload_script_body_only() @@ -1289,7 +1312,11 @@ Requested path was: {f} css += css_hide_progressbar with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: - + with gr.Row(elem_id="quicksettings"): + for i, k, item in quicksettings_list: + component = create_setting_component(k) + component_dict[k] = component + settings_interface.gradio_ref = demo with gr.Tabs() as tabs: @@ -1306,7 +1333,16 @@ Requested path was: {f} inputs=components, outputs=[result, text_settings], ) - + + for i, k, item in quicksettings_list: + component = component_dict[k] + + component.change( + fn=lambda value, k=k: run_settings_single(value, key=k), + inputs=[component], + outputs=[component, text_settings], + ) + def modelmerger(*args): try: results = modules.extras.run_modelmerger(*args) diff --git a/style.css b/style.css index 101d2052..28160bdf 100644 --- a/style.css +++ b/style.css @@ -453,3 +453,12 @@ input[type="range"]{ .context-menu-items a:hover{ background: #a55000; } + +#quicksettings > div{ + border: none; +} + +#quicksettings > div > div{ + max-width: 32em; + padding: 0; +} -- cgit v1.2.3 From a14f7bf113a2af9e06a1c4d06c2efa244f9c5730 Mon Sep 17 00:00:00 2001 From: Fampai Date: Sat, 8 Oct 2022 16:33:06 -0400 Subject: Corrected CLIP Layer Ignore description and updated its range to the max possible --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 270fa402..1995a99a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -225,7 +225,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), - 'CLIP_ignore_last_layers': OptionInfo(0, "Ignore last layers of CLIP model", gr.Slider, {"minimum": 0, "maximum": 5, "step": 1}), + 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) -- cgit v1.2.3 From cc92dc1f8d73dd4d574c4c8ccab78b7fc61e440b Mon Sep 17 00:00:00 2001 From: ssysm Date: Sun, 9 Oct 2022 23:17:29 -0400 Subject: add vae path args --- modules/sd_models.py | 2 +- modules/shared.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index cb3982b1..b6979432 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -147,7 +147,7 @@ def load_model_weights(model, checkpoint_info): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 - vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" + vae_file = shared.cmd_opts.vae_path or os.path.splitext(checkpoint_file)[0] + ".vae.pt" if os.path.exists(vae_file): print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location="cpu") diff --git a/modules/shared.py b/modules/shared.py index 2dc092d6..52ccfa6e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -64,7 +64,7 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) - +parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) cmd_opts = parser.parse_args() -- cgit v1.2.3 From 7349088d32b080f64058b6e5de5f0380a71ecd09 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 16:11:14 +0300 Subject: --no-half-vae --- modules/devices.py | 6 +++++- modules/processing.py | 11 +++++++++-- modules/sd_models.py | 3 +++ modules/sd_samplers.py | 4 ++-- modules/shared.py | 1 + 5 files changed, 20 insertions(+), 5 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/devices.py b/modules/devices.py index 0158b11f..03ef58f1 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -36,6 +36,7 @@ errors.run(enable_tf32, "Enabling TF32") device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() dtype = torch.float16 +dtype_vae = torch.float16 def randn(seed, shape): # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. @@ -59,9 +60,12 @@ def randn_without_seed(shape): return torch.randn(shape, device=device) -def autocast(): +def autocast(disable=False): from modules import shared + if disable: + return contextlib.nullcontext() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() diff --git a/modules/processing.py b/modules/processing.py index 94d2dd62..ec8651ae 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -259,6 +259,13 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see return x +def decode_first_stage(model, x): + with devices.autocast(disable=x.dtype == devices.dtype_vae): + x = model.decode_first_stage(x) + + return x + + def get_fixed_seed(seed): if seed is None or seed == '' or seed == -1: return int(random.randrange(4294967294)) @@ -400,7 +407,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: samples_ddim = samples_ddim.to(devices.dtype) - x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim) + x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) del samples_ddim @@ -533,7 +540,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if self.scale_latent: samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") else: - decoded_samples = self.sd_model.decode_first_stage(samples) + decoded_samples = decode_first_stage(self.sd_model, samples) if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None": decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear") diff --git a/modules/sd_models.py b/modules/sd_models.py index e63d3c29..2cdcd84f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -149,6 +149,7 @@ def load_model_weights(model, checkpoint_info): model.half() devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" if os.path.exists(vae_file): @@ -158,6 +159,8 @@ def load_model_weights(model, checkpoint_info): model.first_stage_model.load_state_dict(vae_dict) + model.first_stage_model.to(devices.dtype_vae) + model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 6e743f7e..d168b938 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -7,7 +7,7 @@ import inspect import k_diffusion.sampling import ldm.models.diffusion.ddim import ldm.models.diffusion.plms -from modules import prompt_parser +from modules import prompt_parser, devices, processing from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -83,7 +83,7 @@ def setup_img2img_steps(p, steps=None): def sample_to_image(samples): - x_sample = shared.sd_model.decode_first_stage(samples[0:1].type(shared.sd_model.dtype))[0] + x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0] x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) diff --git a/modules/shared.py b/modules/shared.py index 1995a99a..5dfc344c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -25,6 +25,7 @@ parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to director parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None) parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") +parser.add_argument("--no-half-vae", action='store_true', help="do not switch the VAE model to 16-bit floats") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") -- cgit v1.2.3 From 39919c40dd18f5a14ae21403efea1b0f819756c7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 20:32:37 +0300 Subject: add eta noise seed delta option --- javascript/hints.js | 1 + modules/processing.py | 6 +++++- modules/shared.py | 1 + 3 files changed, 7 insertions(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/javascript/hints.js b/javascript/hints.js index 8e352e94..47b80776 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -79,6 +79,7 @@ titles = { "Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition", "Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.", + "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.", } diff --git a/modules/processing.py b/modules/processing.py index 50ba4fc5..698b3069 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -207,7 +207,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see # enables the generation of additional tensors with noise that the sampler will use during its processing. # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to # produce the same images as with two batches [100], [101]. - if p is not None and p.sampler is not None and len(seeds) > 1 and opts.enable_batch_seeds: + if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0): sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))] else: sampler_noises = None @@ -247,6 +247,9 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see if sampler_noises is not None: cnt = p.sampler.number_of_needed_noises(p) + if opts.eta_noise_seed_delta > 0: + torch.manual_seed(seed + opts.eta_noise_seed_delta) + for j in range(cnt): sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape))) @@ -301,6 +304,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Denoising strength": getattr(p, 'denoising_strength', None), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Clip skip": None if clip_skip <= 1 else clip_skip, + "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta, } generation_params.update(p.extra_generation_params) diff --git a/modules/shared.py b/modules/shared.py index 5dfc344c..b1c65ecf 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -260,6 +260,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) -- cgit v1.2.3 From f98338faa84ecce503e68d8ba13d5f7bbae52730 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 10 Oct 2022 23:15:48 +0300 Subject: add an option to not add watermark to created images --- javascript/hints.js | 1 + modules/shared.py | 1 + 2 files changed, 2 insertions(+) (limited to 'modules/shared.py') diff --git a/javascript/hints.js b/javascript/hints.js index 47b80776..045f2d3c 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -80,6 +80,7 @@ titles = { "Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.", "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.", + "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be bevaing in an unethical manner.", } diff --git a/modules/shared.py b/modules/shared.py index da389f9c..ecd15ef5 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -173,6 +173,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"), "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"), + "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"), })) options_templates.update(options_section(('saving-paths', "Paths for saving"), { -- cgit v1.2.3 From 76ef3d75f61253516c024553335d9083d9660a8a Mon Sep 17 00:00:00 2001 From: JC_Array Date: Mon, 10 Oct 2022 18:01:49 -0500 Subject: added deepbooru settings (threshold and sort by alpha or likelyhood) --- modules/deepbooru.py | 36 +++++++++++++++++++++++++----------- modules/shared.py | 6 ++++++ 2 files changed, 31 insertions(+), 11 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index ebdba5e0..e31e92c0 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -3,31 +3,32 @@ from concurrent.futures import ProcessPoolExecutor import multiprocessing import time - -def get_deepbooru_tags(pil_image, threshold=0.5): +def get_deepbooru_tags(pil_image): """ This method is for running only one image at a time for simple use. Used to the img2img interrogate. """ from modules import shared # prevents circular reference - create_deepbooru_process(threshold) + create_deepbooru_process(shared.opts.deepbooru_threshold, shared.opts.deepbooru_sort_alpha) shared.deepbooru_process_return["value"] = -1 shared.deepbooru_process_queue.put(pil_image) while shared.deepbooru_process_return["value"] == -1: time.sleep(0.2) + tags = shared.deepbooru_process_return["value"] release_process() + return tags -def deepbooru_process(queue, deepbooru_process_return, threshold): +def deepbooru_process(queue, deepbooru_process_return, threshold, alpha_sort): model, tags = get_deepbooru_tags_model() while True: # while process is running, keep monitoring queue for new image pil_image = queue.get() if pil_image == "QUIT": break else: - deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold) + deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort) -def create_deepbooru_process(threshold=0.5): +def create_deepbooru_process(threshold, alpha_sort): """ Creates deepbooru process. A queue is created to send images into the process. This enables multiple images to be processed in a row without reloading the model or creating a new process. To return the data, a shared @@ -40,7 +41,7 @@ def create_deepbooru_process(threshold=0.5): shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue() shared.deepbooru_process_return = shared.deepbooru_process_manager.dict() shared.deepbooru_process_return["value"] = -1 - shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold)) + shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, alpha_sort)) shared.deepbooru_process.start() @@ -80,7 +81,7 @@ def get_deepbooru_tags_model(): return model, tags -def get_deepbooru_tags_from_model(model, tags, pil_image, threshold=0.5): +def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort): import deepdanbooru as dd import tensorflow as tf import numpy as np @@ -105,15 +106,28 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold=0.5): for i, tag in enumerate(tags): result_dict[tag] = y[i] - result_tags_out = [] + + unsorted_tags_in_theshold = [] result_tags_print = [] for tag in tags: if result_dict[tag] >= threshold: if tag.startswith("rating:"): continue - result_tags_out.append(tag) + unsorted_tags_in_theshold.append((result_dict[tag], tag)) result_tags_print.append(f'{result_dict[tag]} {tag}') + # sort tags + result_tags_out = [] + sort_ndx = 0 + print(alpha_sort) + if alpha_sort: + sort_ndx = 1 + + # sort by reverse by likelihood and normal for alpha + unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort)) + for weight, tag in unsorted_tags_in_theshold: + result_tags_out.append(tag) + print('\n'.join(sorted(result_tags_print, reverse=True))) - return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') \ No newline at end of file + return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') diff --git a/modules/shared.py b/modules/shared.py index 1995a99a..2e307809 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -261,6 +261,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), })) +if cmd_opts.deepdanbooru: + options_templates.update(options_section(('deepbooru-params', "DeepBooru parameters"), { + "deepbooru_sort_alpha": OptionInfo(True, "Sort Alphabetical", gr.Checkbox), + 'deepbooru_threshold': OptionInfo(0.5, "Threshold", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + })) + class Options: data = None -- cgit v1.2.3 From 8617396c6df71074c7fd3d39419802026874712a Mon Sep 17 00:00:00 2001 From: Kenneth Date: Mon, 10 Oct 2022 17:23:07 -0600 Subject: Added slider for deepbooru score threshold in settings --- modules/shared.py | 1 + modules/ui.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index ecd15ef5..e0830e28 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -239,6 +239,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"), + "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), })) options_templates.update(options_section(('ui', "User interface"), { diff --git a/modules/ui.py b/modules/ui.py index cafda884..ca3151c4 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -311,7 +311,7 @@ def interrogate(image): def interrogate_deepbooru(image): - prompt = get_deepbooru_tags(image) + prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold) return gr_show(True) if prompt is None else prompt -- cgit v1.2.3 From 5e2627a1a63e4c9f87e6e604ecc24e9936f149de Mon Sep 17 00:00:00 2001 From: hentailord85ez <112723046+hentailord85ez@users.noreply.github.com> Date: Tue, 11 Oct 2022 07:55:28 +0100 Subject: Comma backtrack padding (#2192) Comma backtrack padding --- modules/sd_hijack.py | 19 ++++++++++++++++++- modules/shared.py | 1 + 2 files changed, 19 insertions(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 827bf304..aa4d2cbc 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -107,6 +107,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.tokenizer = wrapped.tokenizer self.token_mults = {} + self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] + tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] for text, ident in tokens_with_parens: mult = 1.0 @@ -136,6 +138,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): fixes = [] remade_tokens = [] multipliers = [] + last_comma = -1 for tokens, (text, weight) in zip(tokenized, parsed): i = 0 @@ -144,6 +147,20 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + if token == self.comma_token: + last_comma = len(remade_tokens) + elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack: + last_comma += 1 + reloc_tokens = remade_tokens[last_comma:] + reloc_mults = multipliers[last_comma:] + + remade_tokens = remade_tokens[:last_comma] + length = len(remade_tokens) + + rem = int(math.ceil(length / 75)) * 75 - length + remade_tokens += [id_end] * rem + reloc_tokens + multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults + if embedding is None: remade_tokens.append(token) multipliers.append(weight) @@ -284,7 +301,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while max(map(len, remade_batch_tokens)) != 0: rem_tokens = [x[75:] for x in remade_batch_tokens] rem_multipliers = [x[75:] for x in batch_multipliers] - + self.hijack.fixes = [] for unfiltered in hijack_fixes: fixes = [] diff --git a/modules/shared.py b/modules/shared.py index e0830e28..14b40d70 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -227,6 +227,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), + "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), -- cgit v1.2.3 From 530103b586109c11fd068eb70ef09503ec6a4caf Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 14:53:02 +0300 Subject: fixes related to merge --- modules/hypernetwork.py | 103 ------------------------- modules/hypernetwork/hypernetwork.py | 74 +++++++++++------- modules/hypernetwork/ui.py | 10 +-- modules/sd_hijack_optimizations.py | 3 +- modules/shared.py | 13 +++- modules/textual_inversion/textual_inversion.py | 12 +-- modules/ui.py | 5 +- scripts/xy_grid.py | 3 +- webui.py | 15 +--- 9 files changed, 78 insertions(+), 160 deletions(-) delete mode 100644 modules/hypernetwork.py (limited to 'modules/shared.py') diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py deleted file mode 100644 index 7bbc443e..00000000 --- a/modules/hypernetwork.py +++ /dev/null @@ -1,103 +0,0 @@ -import glob -import os -import sys -import traceback - -import torch - -from ldm.util import default -from modules import devices, shared -import torch -from torch import einsum -from einops import rearrange, repeat - - -class HypernetworkModule(torch.nn.Module): - def __init__(self, dim, state_dict): - super().__init__() - - self.linear1 = torch.nn.Linear(dim, dim * 2) - self.linear2 = torch.nn.Linear(dim * 2, dim) - - self.load_state_dict(state_dict, strict=True) - self.to(devices.device) - - def forward(self, x): - return x + (self.linear2(self.linear1(x))) - - -class Hypernetwork: - filename = None - name = None - - def __init__(self, filename): - self.filename = filename - self.name = os.path.splitext(os.path.basename(filename))[0] - self.layers = {} - - state_dict = torch.load(filename, map_location='cpu') - for size, sd in state_dict.items(): - self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) - - -def list_hypernetworks(path): - res = {} - for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): - name = os.path.splitext(os.path.basename(filename))[0] - res[name] = filename - return res - - -def load_hypernetwork(filename): - path = shared.hypernetworks.get(filename, None) - if path is not None: - print(f"Loading hypernetwork {filename}") - try: - shared.loaded_hypernetwork = Hypernetwork(path) - except Exception: - print(f"Error loading hypernetwork {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - if shared.loaded_hypernetwork is not None: - print(f"Unloading hypernetwork") - - shared.loaded_hypernetwork = None - - -def apply_hypernetwork(hypernetwork, context): - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is None: - return context, context - - context_k = hypernetwork_layers[0](context) - context_v = hypernetwork_layers[1](context) - return context_k, context_v - - -def attention_CrossAttention_forward(self, x, context=None, mask=None): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - - context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context) - k = self.to_k(context_k) - v = self.to_v(context_v) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - - if mask is not None: - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) - - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) diff --git a/modules/hypernetwork/hypernetwork.py b/modules/hypernetwork/hypernetwork.py index a3d6a47e..aa701bda 100644 --- a/modules/hypernetwork/hypernetwork.py +++ b/modules/hypernetwork/hypernetwork.py @@ -26,10 +26,11 @@ class HypernetworkModule(torch.nn.Module): if state_dict is not None: self.load_state_dict(state_dict, strict=True) else: - self.linear1.weight.data.fill_(0.0001) - self.linear1.bias.data.fill_(0.0001) - self.linear2.weight.data.fill_(0.0001) - self.linear2.bias.data.fill_(0.0001) + + self.linear1.weight.data.normal_(mean=0.0, std=0.01) + self.linear1.bias.data.zero_() + self.linear2.weight.data.normal_(mean=0.0, std=0.01) + self.linear2.bias.data.zero_() self.to(devices.device) @@ -92,41 +93,54 @@ class Hypernetwork: self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) -def load_hypernetworks(path): +def list_hypernetworks(path): res = {} + for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): + name = os.path.splitext(os.path.basename(filename))[0] + res[name] = filename + return res - for filename in glob.iglob(path + '**/*.pt', recursive=True): + +def load_hypernetwork(filename): + path = shared.hypernetworks.get(filename, None) + if path is not None: + print(f"Loading hypernetwork {filename}") try: - hn = Hypernetwork() - hn.load(filename) - res[hn.name] = hn + shared.loaded_hypernetwork = Hypernetwork() + shared.loaded_hypernetwork.load(path) + except Exception: - print(f"Error loading hypernetwork {filename}", file=sys.stderr) + print(f"Error loading hypernetwork {path}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + else: + if shared.loaded_hypernetwork is not None: + print(f"Unloading hypernetwork") - return res + shared.loaded_hypernetwork = None -def attention_CrossAttention_forward(self, x, context=None, mask=None): - h = self.heads +def apply_hypernetwork(hypernetwork, context, layer=None): + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - q = self.to_q(x) - context = default(context, x) + if hypernetwork_layers is None: + return context, context - hypernetwork_layers = (shared.hypernetwork.layers if shared.hypernetwork is not None else {}).get(context.shape[2], None) + if layer is not None: + layer.hyper_k = hypernetwork_layers[0] + layer.hyper_v = hypernetwork_layers[1] - if hypernetwork_layers is not None: - hypernetwork_k, hypernetwork_v = hypernetwork_layers + context_k = hypernetwork_layers[0](context) + context_v = hypernetwork_layers[1](context) + return context_k, context_v - self.hypernetwork_k = hypernetwork_k - self.hypernetwork_v = hypernetwork_v - context_k = hypernetwork_k(context) - context_v = hypernetwork_v(context) - else: - context_k = context - context_v = context +def attention_CrossAttention_forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self) k = self.to_k(context_k) v = self.to_v(context_v) @@ -151,7 +165,9 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): assert hypernetwork_name, 'embedding not selected' - shared.hypernetwork = shared.hypernetworks[hypernetwork_name] + path = shared.hypernetworks.get(hypernetwork_name, None) + shared.loaded_hypernetwork = Hypernetwork() + shared.loaded_hypernetwork.load(path) shared.state.textinfo = "Initializing hypernetwork training..." shared.state.job_count = steps @@ -176,9 +192,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) - hypernetwork = shared.hypernetworks[hypernetwork_name] + hypernetwork = shared.loaded_hypernetwork weights = hypernetwork.weights() for weight in weights: weight.requires_grad = True @@ -194,7 +210,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, if ititial_step > steps: return hypernetwork, filename - pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) for i, (x, text) in pbar: hypernetwork.step = i + ititial_step diff --git a/modules/hypernetwork/ui.py b/modules/hypernetwork/ui.py index 525f978c..f6d1d0a3 100644 --- a/modules/hypernetwork/ui.py +++ b/modules/hypernetwork/ui.py @@ -6,24 +6,24 @@ import gradio as gr import modules.textual_inversion.textual_inversion import modules.textual_inversion.preprocess from modules import sd_hijack, shared +from modules.hypernetwork import hypernetwork def create_hypernetwork(name): fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") assert not os.path.exists(fn), f"file {fn} already exists" - hypernetwork = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) - hypernetwork.save(fn) + hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) + hypernet.save(fn) shared.reload_hypernetworks() - shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None) return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" def train_hypernetwork(*args): - initial_hypernetwork = shared.hypernetwork + initial_hypernetwork = shared.loaded_hypernetwork try: sd_hijack.undo_optimizations() @@ -38,6 +38,6 @@ Hypernetwork saved to {html.escape(filename)} except Exception: raise finally: - shared.hypernetwork = initial_hypernetwork + shared.loaded_hypernetwork = initial_hypernetwork sd_hijack.apply_optimizations() diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 25cb67a4..27e571fc 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -8,7 +8,8 @@ from torch import einsum from ldm.util import default from einops import rearrange -from modules import shared, hypernetwork +from modules import shared +from modules.hypernetwork import hypernetwork if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: diff --git a/modules/shared.py b/modules/shared.py index 14b40d70..8753015e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,8 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, hypernetwork +from modules import sd_samplers +from modules.hypernetwork import hypernetwork from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') @@ -29,6 +30,7 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") @@ -82,10 +84,17 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram xformers_available = False config_filename = cmd_opts.ui_settings_file -hypernetworks = hypernetwork.list_hypernetworks(os.path.join(models_path, 'hypernetworks')) +hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None +def reload_hypernetworks(): + global hypernetworks + + hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) + hypernetwork.load_hypernetwork(opts.sd_hypernetwork) + + class State: skipped = False interrupted = False diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5965c5a0..d6977950 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -156,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -238,12 +238,14 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') + preview_text = text if preview_image_prompt == "" else preview_image_prompt + p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, - prompt=text, + prompt=preview_text, steps=20, - height=training_height, - width=training_width, + height=training_height, + width=training_width, do_not_save_grid=True, do_not_save_samples=True, ) @@ -254,7 +256,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.current_image = image image.save(last_saved_image) - last_saved_image += f", prompt: {text}" + last_saved_image += f", prompt: {preview_text}" shared.state.job_no = embedding.step diff --git a/modules/ui.py b/modules/ui.py index 10b1ee3a..df653059 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1023,7 +1023,7 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="") with gr.Column(): - create_embedding = gr.Button(value="Create", variant='primary') + create_embedding = gr.Button(value="Create embedding", variant='primary') with gr.Group(): gr.HTML(value="

Create a new hypernetwork

") @@ -1035,7 +1035,7 @@ def create_ui(wrap_gradio_gpu_call): gr.HTML(value="") with gr.Column(): - create_hypernetwork = gr.Button(value="Create", variant='primary') + create_hypernetwork = gr.Button(value="Create hypernetwork", variant='primary') with gr.Group(): gr.HTML(value="

Preprocess images

") @@ -1147,6 +1147,7 @@ def create_ui(wrap_gradio_gpu_call): create_image_every, save_embedding_every, template_file, + preview_image_prompt, ], outputs=[ ti_output, diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 42e1489c..0af5993c 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,7 +10,8 @@ import numpy as np import modules.scripts as scripts import gradio as gr -from modules import images, hypernetwork +from modules import images +from modules.hypernetwork import hypernetwork from modules.processing import process_images, Processed, get_correct_sampler from modules.shared import opts, cmd_opts, state import modules.shared as shared diff --git a/webui.py b/webui.py index 7c200551..ba2156c8 100644 --- a/webui.py +++ b/webui.py @@ -29,6 +29,7 @@ from modules import devices from modules import modelloader from modules.paths import script_path from modules.shared import cmd_opts +import modules.hypernetwork.hypernetwork modelloader.cleanup_models() modules.sd_models.setup_model() @@ -77,22 +78,12 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) -def set_hypernetwork(): - shared.hypernetwork = shared.hypernetworks.get(shared.opts.sd_hypernetwork, None) - - -shared.reload_hypernetworks() -shared.opts.onchange("sd_hypernetwork", set_hypernetwork) -set_hypernetwork() - - modules.scripts.load_scripts(os.path.join(script_path, "scripts")) shared.sd_model = modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) -loaded_hypernetwork = modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork) -shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) +shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) def webui(): @@ -117,7 +108,7 @@ def webui(): prevent_thread_lock=True ) - app.add_middleware(GZipMiddleware,minimum_size=1000) + app.add_middleware(GZipMiddleware, minimum_size=1000) while 1: time.sleep(0.5) -- cgit v1.2.3 From 59925644480b6fd84f6bb84b4df7d4fbc6a0cce8 Mon Sep 17 00:00:00 2001 From: JamnedZ Date: Tue, 11 Oct 2022 16:40:27 +0700 Subject: Cleaned ngrok integration --- modules/ngrok.py | 15 +++++++++++++++ modules/shared.py | 1 + modules/ui.py | 5 +++++ 3 files changed, 21 insertions(+) create mode 100644 modules/ngrok.py (limited to 'modules/shared.py') diff --git a/modules/ngrok.py b/modules/ngrok.py new file mode 100644 index 00000000..17e6976f --- /dev/null +++ b/modules/ngrok.py @@ -0,0 +1,15 @@ +from pyngrok import ngrok, conf, exception + + +def connect(token, port): + if token == None: + token = 'None' + conf.get_default().auth_token = token + try: + public_url = ngrok.connect(port).public_url + except exception.PyngrokNgrokError: + print(f'Invalid ngrok authtoken, ngrok connection aborted.\n' + f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken') + else: + print(f'ngrok connected to localhost:{port}! URL: {public_url}\n' + 'You can use this link after the launch is complete.') \ No newline at end of file diff --git a/modules/shared.py b/modules/shared.py index 8753015e..375e3afb 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -38,6 +38,7 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") +parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer')) parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN')) parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN')) diff --git a/modules/ui.py b/modules/ui.py index fc0f3d3c..f57f32db 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -51,6 +51,11 @@ if not cmd_opts.share and not cmd_opts.listen: gradio.utils.version_check = lambda: None gradio.utils.get_local_ip_address = lambda: '127.0.0.1' +if cmd_opts.ngrok != None: + import modules.ngrok as ngrok + print('ngrok authtoken detected, trying to connect...') + ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860) + def gr_show(visible=True): return {"visible": visible, "__type__": "update"} -- cgit v1.2.3 From 873efeed49bb5197a42da18272115b326c5d68f3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 15:51:22 +0300 Subject: rename hypernetwork dir to hypernetworks to prevent clash with an old filename that people who use zip instead of git clone will have --- modules/hypernetwork/hypernetwork.py | 283 ---------------------------------- modules/hypernetwork/ui.py | 43 ------ modules/hypernetworks/hypernetwork.py | 283 ++++++++++++++++++++++++++++++++++ modules/hypernetworks/ui.py | 43 ++++++ modules/sd_hijack.py | 2 +- modules/sd_hijack_optimizations.py | 2 +- modules/shared.py | 2 +- modules/ui.py | 2 +- scripts/xy_grid.py | 2 +- webui.py | 2 +- 10 files changed, 332 insertions(+), 332 deletions(-) delete mode 100644 modules/hypernetwork/hypernetwork.py delete mode 100644 modules/hypernetwork/ui.py create mode 100644 modules/hypernetworks/hypernetwork.py create mode 100644 modules/hypernetworks/ui.py (limited to 'modules/shared.py') diff --git a/modules/hypernetwork/hypernetwork.py b/modules/hypernetwork/hypernetwork.py deleted file mode 100644 index aa701bda..00000000 --- a/modules/hypernetwork/hypernetwork.py +++ /dev/null @@ -1,283 +0,0 @@ -import datetime -import glob -import html -import os -import sys -import traceback -import tqdm - -import torch - -from ldm.util import default -from modules import devices, shared, processing, sd_models -import torch -from torch import einsum -from einops import rearrange, repeat -import modules.textual_inversion.dataset - - -class HypernetworkModule(torch.nn.Module): - def __init__(self, dim, state_dict=None): - super().__init__() - - self.linear1 = torch.nn.Linear(dim, dim * 2) - self.linear2 = torch.nn.Linear(dim * 2, dim) - - if state_dict is not None: - self.load_state_dict(state_dict, strict=True) - else: - - self.linear1.weight.data.normal_(mean=0.0, std=0.01) - self.linear1.bias.data.zero_() - self.linear2.weight.data.normal_(mean=0.0, std=0.01) - self.linear2.bias.data.zero_() - - self.to(devices.device) - - def forward(self, x): - return x + (self.linear2(self.linear1(x))) - - -class Hypernetwork: - filename = None - name = None - - def __init__(self, name=None): - self.filename = None - self.name = name - self.layers = {} - self.step = 0 - self.sd_checkpoint = None - self.sd_checkpoint_name = None - - for size in [320, 640, 768, 1280]: - self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size)) - - def weights(self): - res = [] - - for k, layers in self.layers.items(): - for layer in layers: - layer.train() - res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias] - - return res - - def save(self, filename): - state_dict = {} - - for k, v in self.layers.items(): - state_dict[k] = (v[0].state_dict(), v[1].state_dict()) - - state_dict['step'] = self.step - state_dict['name'] = self.name - state_dict['sd_checkpoint'] = self.sd_checkpoint - state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name - - torch.save(state_dict, filename) - - def load(self, filename): - self.filename = filename - if self.name is None: - self.name = os.path.splitext(os.path.basename(filename))[0] - - state_dict = torch.load(filename, map_location='cpu') - - for size, sd in state_dict.items(): - if type(size) == int: - self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) - - self.name = state_dict.get('name', self.name) - self.step = state_dict.get('step', 0) - self.sd_checkpoint = state_dict.get('sd_checkpoint', None) - self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) - - -def list_hypernetworks(path): - res = {} - for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): - name = os.path.splitext(os.path.basename(filename))[0] - res[name] = filename - return res - - -def load_hypernetwork(filename): - path = shared.hypernetworks.get(filename, None) - if path is not None: - print(f"Loading hypernetwork {filename}") - try: - shared.loaded_hypernetwork = Hypernetwork() - shared.loaded_hypernetwork.load(path) - - except Exception: - print(f"Error loading hypernetwork {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - if shared.loaded_hypernetwork is not None: - print(f"Unloading hypernetwork") - - shared.loaded_hypernetwork = None - - -def apply_hypernetwork(hypernetwork, context, layer=None): - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) - - if hypernetwork_layers is None: - return context, context - - if layer is not None: - layer.hyper_k = hypernetwork_layers[0] - layer.hyper_v = hypernetwork_layers[1] - - context_k = hypernetwork_layers[0](context) - context_v = hypernetwork_layers[1](context) - return context_k, context_v - - -def attention_CrossAttention_forward(self, x, context=None, mask=None): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - - context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self) - k = self.to_k(context_k) - v = self.to_v(context_v) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - - if mask is not None: - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) - - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) - - -def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): - assert hypernetwork_name, 'embedding not selected' - - path = shared.hypernetworks.get(hypernetwork_name, None) - shared.loaded_hypernetwork = Hypernetwork() - shared.loaded_hypernetwork.load(path) - - shared.state.textinfo = "Initializing hypernetwork training..." - shared.state.job_count = steps - - filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') - - log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) - - if save_hypernetwork_every > 0: - hypernetwork_dir = os.path.join(log_directory, "hypernetworks") - os.makedirs(hypernetwork_dir, exist_ok=True) - else: - hypernetwork_dir = None - - if create_image_every > 0: - images_dir = os.path.join(log_directory, "images") - os.makedirs(images_dir, exist_ok=True) - else: - images_dir = None - - cond_model = shared.sd_model.cond_stage_model - - shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." - with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) - - hypernetwork = shared.loaded_hypernetwork - weights = hypernetwork.weights() - for weight in weights: - weight.requires_grad = True - - optimizer = torch.optim.AdamW(weights, lr=learn_rate) - - losses = torch.zeros((32,)) - - last_saved_file = "" - last_saved_image = "" - - ititial_step = hypernetwork.step or 0 - if ititial_step > steps: - return hypernetwork, filename - - pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - for i, (x, text) in pbar: - hypernetwork.step = i + ititial_step - - if hypernetwork.step > steps: - break - - if shared.state.interrupted: - break - - with torch.autocast("cuda"): - c = cond_model([text]) - - x = x.to(devices.device) - loss = shared.sd_model(x.unsqueeze(0), c)[0] - del x - - losses[hypernetwork.step % losses.shape[0]] = loss.item() - - optimizer.zero_grad() - loss.backward() - optimizer.step() - - pbar.set_description(f"loss: {losses.mean():.7f}") - - if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: - last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') - hypernetwork.save(last_saved_file) - - if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: - last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') - - preview_text = text if preview_image_prompt == "" else preview_image_prompt - - p = processing.StableDiffusionProcessingTxt2Img( - sd_model=shared.sd_model, - prompt=preview_text, - steps=20, - do_not_save_grid=True, - do_not_save_samples=True, - ) - - processed = processing.process_images(p) - image = processed.images[0] - - shared.state.current_image = image - image.save(last_saved_image) - - last_saved_image += f", prompt: {preview_text}" - - shared.state.job_no = hypernetwork.step - - shared.state.textinfo = f""" -

-Loss: {losses.mean():.7f}
-Step: {hypernetwork.step}
-Last prompt: {html.escape(text)}
-Last saved embedding: {html.escape(last_saved_file)}
-Last saved image: {html.escape(last_saved_image)}
-

-""" - - checkpoint = sd_models.select_checkpoint() - - hypernetwork.sd_checkpoint = checkpoint.hash - hypernetwork.sd_checkpoint_name = checkpoint.model_name - hypernetwork.save(filename) - - return hypernetwork, filename - - diff --git a/modules/hypernetwork/ui.py b/modules/hypernetwork/ui.py deleted file mode 100644 index f6d1d0a3..00000000 --- a/modules/hypernetwork/ui.py +++ /dev/null @@ -1,43 +0,0 @@ -import html -import os - -import gradio as gr - -import modules.textual_inversion.textual_inversion -import modules.textual_inversion.preprocess -from modules import sd_hijack, shared -from modules.hypernetwork import hypernetwork - - -def create_hypernetwork(name): - fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") - assert not os.path.exists(fn), f"file {fn} already exists" - - hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) - hypernet.save(fn) - - shared.reload_hypernetworks() - - return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" - - -def train_hypernetwork(*args): - - initial_hypernetwork = shared.loaded_hypernetwork - - try: - sd_hijack.undo_optimizations() - - hypernetwork, filename = modules.hypernetwork.hypernetwork.train_hypernetwork(*args) - - res = f""" -Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. -Hypernetwork saved to {html.escape(filename)} -""" - return res, "" - except Exception: - raise - finally: - shared.loaded_hypernetwork = initial_hypernetwork - sd_hijack.apply_optimizations() - diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py new file mode 100644 index 00000000..aa701bda --- /dev/null +++ b/modules/hypernetworks/hypernetwork.py @@ -0,0 +1,283 @@ +import datetime +import glob +import html +import os +import sys +import traceback +import tqdm + +import torch + +from ldm.util import default +from modules import devices, shared, processing, sd_models +import torch +from torch import einsum +from einops import rearrange, repeat +import modules.textual_inversion.dataset + + +class HypernetworkModule(torch.nn.Module): + def __init__(self, dim, state_dict=None): + super().__init__() + + self.linear1 = torch.nn.Linear(dim, dim * 2) + self.linear2 = torch.nn.Linear(dim * 2, dim) + + if state_dict is not None: + self.load_state_dict(state_dict, strict=True) + else: + + self.linear1.weight.data.normal_(mean=0.0, std=0.01) + self.linear1.bias.data.zero_() + self.linear2.weight.data.normal_(mean=0.0, std=0.01) + self.linear2.bias.data.zero_() + + self.to(devices.device) + + def forward(self, x): + return x + (self.linear2(self.linear1(x))) + + +class Hypernetwork: + filename = None + name = None + + def __init__(self, name=None): + self.filename = None + self.name = name + self.layers = {} + self.step = 0 + self.sd_checkpoint = None + self.sd_checkpoint_name = None + + for size in [320, 640, 768, 1280]: + self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size)) + + def weights(self): + res = [] + + for k, layers in self.layers.items(): + for layer in layers: + layer.train() + res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias] + + return res + + def save(self, filename): + state_dict = {} + + for k, v in self.layers.items(): + state_dict[k] = (v[0].state_dict(), v[1].state_dict()) + + state_dict['step'] = self.step + state_dict['name'] = self.name + state_dict['sd_checkpoint'] = self.sd_checkpoint + state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name + + torch.save(state_dict, filename) + + def load(self, filename): + self.filename = filename + if self.name is None: + self.name = os.path.splitext(os.path.basename(filename))[0] + + state_dict = torch.load(filename, map_location='cpu') + + for size, sd in state_dict.items(): + if type(size) == int: + self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) + + self.name = state_dict.get('name', self.name) + self.step = state_dict.get('step', 0) + self.sd_checkpoint = state_dict.get('sd_checkpoint', None) + self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) + + +def list_hypernetworks(path): + res = {} + for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True): + name = os.path.splitext(os.path.basename(filename))[0] + res[name] = filename + return res + + +def load_hypernetwork(filename): + path = shared.hypernetworks.get(filename, None) + if path is not None: + print(f"Loading hypernetwork {filename}") + try: + shared.loaded_hypernetwork = Hypernetwork() + shared.loaded_hypernetwork.load(path) + + except Exception: + print(f"Error loading hypernetwork {path}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + else: + if shared.loaded_hypernetwork is not None: + print(f"Unloading hypernetwork") + + shared.loaded_hypernetwork = None + + +def apply_hypernetwork(hypernetwork, context, layer=None): + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is None: + return context, context + + if layer is not None: + layer.hyper_k = hypernetwork_layers[0] + layer.hyper_v = hypernetwork_layers[1] + + context_k = hypernetwork_layers[0](context) + context_v = hypernetwork_layers[1](context) + return context_k, context_v + + +def attention_CrossAttention_forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self) + k = self.to_k(context_k) + v = self.to_v(context_v) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if mask is not None: + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): + assert hypernetwork_name, 'embedding not selected' + + path = shared.hypernetworks.get(hypernetwork_name, None) + shared.loaded_hypernetwork = Hypernetwork() + shared.loaded_hypernetwork.load(path) + + shared.state.textinfo = "Initializing hypernetwork training..." + shared.state.job_count = steps + + filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') + + log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) + + if save_hypernetwork_every > 0: + hypernetwork_dir = os.path.join(log_directory, "hypernetworks") + os.makedirs(hypernetwork_dir, exist_ok=True) + else: + hypernetwork_dir = None + + if create_image_every > 0: + images_dir = os.path.join(log_directory, "images") + os.makedirs(images_dir, exist_ok=True) + else: + images_dir = None + + cond_model = shared.sd_model.cond_stage_model + + shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." + with torch.autocast("cuda"): + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) + + hypernetwork = shared.loaded_hypernetwork + weights = hypernetwork.weights() + for weight in weights: + weight.requires_grad = True + + optimizer = torch.optim.AdamW(weights, lr=learn_rate) + + losses = torch.zeros((32,)) + + last_saved_file = "" + last_saved_image = "" + + ititial_step = hypernetwork.step or 0 + if ititial_step > steps: + return hypernetwork, filename + + pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) + for i, (x, text) in pbar: + hypernetwork.step = i + ititial_step + + if hypernetwork.step > steps: + break + + if shared.state.interrupted: + break + + with torch.autocast("cuda"): + c = cond_model([text]) + + x = x.to(devices.device) + loss = shared.sd_model(x.unsqueeze(0), c)[0] + del x + + losses[hypernetwork.step % losses.shape[0]] = loss.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + pbar.set_description(f"loss: {losses.mean():.7f}") + + if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0: + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') + hypernetwork.save(last_saved_file) + + if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: + last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') + + preview_text = text if preview_image_prompt == "" else preview_image_prompt + + p = processing.StableDiffusionProcessingTxt2Img( + sd_model=shared.sd_model, + prompt=preview_text, + steps=20, + do_not_save_grid=True, + do_not_save_samples=True, + ) + + processed = processing.process_images(p) + image = processed.images[0] + + shared.state.current_image = image + image.save(last_saved_image) + + last_saved_image += f", prompt: {preview_text}" + + shared.state.job_no = hypernetwork.step + + shared.state.textinfo = f""" +

+Loss: {losses.mean():.7f}
+Step: {hypernetwork.step}
+Last prompt: {html.escape(text)}
+Last saved embedding: {html.escape(last_saved_file)}
+Last saved image: {html.escape(last_saved_image)}
+

+""" + + checkpoint = sd_models.select_checkpoint() + + hypernetwork.sd_checkpoint = checkpoint.hash + hypernetwork.sd_checkpoint_name = checkpoint.model_name + hypernetwork.save(filename) + + return hypernetwork, filename + + diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py new file mode 100644 index 00000000..811bc31e --- /dev/null +++ b/modules/hypernetworks/ui.py @@ -0,0 +1,43 @@ +import html +import os + +import gradio as gr + +import modules.textual_inversion.textual_inversion +import modules.textual_inversion.preprocess +from modules import sd_hijack, shared +from modules.hypernetworks import hypernetwork + + +def create_hypernetwork(name): + fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") + assert not os.path.exists(fn), f"file {fn} already exists" + + hypernet = modules.hypernetwork.hypernetwork.Hypernetwork(name=name) + hypernet.save(fn) + + shared.reload_hypernetworks() + + return gr.Dropdown.update(choices=sorted([x for x in shared.hypernetworks.keys()])), f"Created: {fn}", "" + + +def train_hypernetwork(*args): + + initial_hypernetwork = shared.loaded_hypernetwork + + try: + sd_hijack.undo_optimizations() + + hypernetwork, filename = modules.hypernetwork.hypernetwork.train_hypernetwork(*args) + + res = f""" +Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. +Hypernetwork saved to {html.escape(filename)} +""" + return res, "" + except Exception: + raise + finally: + shared.loaded_hypernetwork = initial_hypernetwork + sd_hijack.apply_optimizations() + diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f873049a..f07ec041 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -37,7 +37,7 @@ def apply_optimizations(): def undo_optimizations(): - from modules.hypernetwork import hypernetwork + from modules.hypernetworks import hypernetwork ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 27e571fc..3349b9c3 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -9,7 +9,7 @@ from ldm.util import default from einops import rearrange from modules import shared -from modules.hypernetwork import hypernetwork +from modules.hypernetworks import hypernetwork if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: diff --git a/modules/shared.py b/modules/shared.py index 375e3afb..1dc2ccf2 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.sd_models import modules.styles import modules.devices as devices from modules import sd_samplers -from modules.hypernetwork import hypernetwork +from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path sd_model_file = os.path.join(script_path, 'model.ckpt') diff --git a/modules/ui.py b/modules/ui.py index f57f32db..42e5d866 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -39,7 +39,7 @@ import modules.generation_parameters_copypaste from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui -import modules.hypernetwork.ui +import modules.hypernetworks.ui # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 16918c99..cddb192a 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -11,7 +11,7 @@ import modules.scripts as scripts import gradio as gr from modules import images -from modules.hypernetwork import hypernetwork +from modules.hypernetworks import hypernetwork from modules.processing import process_images, Processed, get_correct_sampler from modules.shared import opts, cmd_opts, state import modules.shared as shared diff --git a/webui.py b/webui.py index ba2156c8..faa38a0d 100644 --- a/webui.py +++ b/webui.py @@ -29,7 +29,7 @@ from modules import devices from modules import modelloader from modules.paths import script_path from modules.shared import cmd_opts -import modules.hypernetwork.hypernetwork +import modules.hypernetworks.hypernetwork modelloader.cleanup_models() modules.sd_models.setup_model() -- cgit v1.2.3 From c0484f1b986ce7acb0e3596f6089a191279f5442 Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 10 Oct 2022 22:48:54 -0400 Subject: Add cross-attention optimization from InvokeAI * Add cross-attention optimization from InvokeAI (~30% speed improvement on MPS) * Add command line option for it * Make it default when CUDA is unavailable --- modules/sd_hijack.py | 5 ++- modules/sd_hijack_optimizations.py | 79 ++++++++++++++++++++++++++++++++++++++ modules/shared.py | 5 ++- 3 files changed, 86 insertions(+), 3 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index f07ec041..5a1b167f 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -30,8 +30,11 @@ def apply_optimizations(): elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 + elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): + print("Applying cross attention optimization (InvokeAI).") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()): - print("Applying cross attention optimization.") + print("Applying cross attention optimization (Doggettx).") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 3349b9c3..870226c5 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,6 +1,7 @@ import math import sys import traceback +import psutil import torch from torch import einsum @@ -116,6 +117,84 @@ def split_cross_attention_forward(self, x, context=None, mask=None): return self.to_out(r2) +# -- From https://github.com/invoke-ai/InvokeAI/blob/main/ldm/modules/attention.py (with hypernetworks support added) -- + +mem_total_gb = psutil.virtual_memory().total // (1 << 30) + +def einsum_op_compvis(q, k, v): + s = einsum('b i d, b j d -> b i j', q, k) + s = s.softmax(dim=-1, dtype=s.dtype) + return einsum('b i j, b j d -> b i d', s, v) + +def einsum_op_slice_0(q, k, v, slice_size): + r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[0], slice_size): + end = i + slice_size + r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end]) + return r + +def einsum_op_slice_1(q, k, v, slice_size): + r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v) + return r + +def einsum_op_mps_v1(q, k, v): + if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 + return einsum_op_compvis(q, k, v) + else: + slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) + return einsum_op_slice_1(q, k, v, slice_size) + +def einsum_op_mps_v2(q, k, v): + if mem_total_gb > 8 and q.shape[1] <= 4096: + return einsum_op_compvis(q, k, v) + else: + return einsum_op_slice_0(q, k, v, 1) + +def einsum_op_tensor_mem(q, k, v, max_tensor_mb): + size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) + if size_mb <= max_tensor_mb: + return einsum_op_compvis(q, k, v) + div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() + if div <= q.shape[0]: + return einsum_op_slice_0(q, k, v, q.shape[0] // div) + return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) + +def einsum_op(q, k, v): + if q.device.type == 'mps': + if mem_total_gb >= 32: + return einsum_op_mps_v1(q, k, v) + return einsum_op_mps_v2(q, k, v) + + # Smaller slices are faster due to L2/L3/SLC caches. + # Tested on i7 with 8MB L3 cache. + return einsum_op_tensor_mem(q, k, v, 32) + +def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + + hypernetwork = shared.loaded_hypernetwork + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) + + if hypernetwork_layers is not None: + k = self.to_k(hypernetwork_layers[0](context)) * self.scale + v = self.to_v(hypernetwork_layers[1](context)) + else: + k = self.to_k(context) * self.scale + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + r = einsum_op(q, k, v) + return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) + +# -- End of code from https://github.com/invoke-ai/InvokeAI/blob/main/ldm/modules/attention.py -- + def xformers_attention_forward(self, x, context=None, mask=None): h = self.heads q_in = self.to_q(x) diff --git a/modules/shared.py b/modules/shared.py index 1dc2ccf2..20b45f23 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -50,9 +50,10 @@ parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator") -parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") -parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") +parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.") +parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") +parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[]) parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) -- cgit v1.2.3 From d4ea5f4d8631f778d11efcde397e4a5b8801d43b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 11 Oct 2022 19:03:08 +0300 Subject: add an option to unload models during hypernetwork training to save VRAM --- modules/hypernetworks/hypernetwork.py | 25 +++++++++++++++------- modules/hypernetworks/ui.py | 4 +++- modules/shared.py | 4 ++++ modules/textual_inversion/dataset.py | 29 ++++++++++++++++++-------- modules/textual_inversion/textual_inversion.py | 2 +- 5 files changed, 46 insertions(+), 18 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b081f14e..4700e1ec 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -175,6 +175,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name) + unload = shared.opts.unload_models_when_training if save_hypernetwork_every > 0: hypernetwork_dir = os.path.join(log_directory, "hypernetworks") @@ -188,11 +189,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, else: images_dir = None - cond_model = shared.sd_model.cond_stage_model - shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True) + + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) hypernetwork = shared.loaded_hypernetwork weights = hypernetwork.weights() @@ -211,7 +214,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, return hypernetwork, filename pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - for i, (x, text) in pbar: + for i, (x, text, cond) in pbar: hypernetwork.step = i + ititial_step if hypernetwork.step > steps: @@ -221,11 +224,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, break with torch.autocast("cuda"): - c = cond_model([text]) - + cond = cond.to(devices.device) x = x.to(devices.device) - loss = shared.sd_model(x.unsqueeze(0), c)[0] + loss = shared.sd_model(x.unsqueeze(0), cond)[0] del x + del cond losses[hypernetwork.step % losses.shape[0]] = loss.item() @@ -244,6 +247,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, preview_text = text if preview_image_prompt == "" else preview_image_prompt + optimizer.zero_grad() + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) + p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, prompt=preview_text, @@ -255,6 +262,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, processed = processing.process_images(p) image = processed.images[0] + if unload: + shared.sd_model.cond_stage_model.to(devices.cpu) + shared.sd_model.first_stage_model.to(devices.cpu) + shared.state.current_image = image image.save(last_saved_image) diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 3541a388..c67facbb 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -5,7 +5,7 @@ import gradio as gr import modules.textual_inversion.textual_inversion import modules.textual_inversion.preprocess -from modules import sd_hijack, shared +from modules import sd_hijack, shared, devices from modules.hypernetworks import hypernetwork @@ -41,5 +41,7 @@ Hypernetwork saved to {html.escape(filename)} raise finally: shared.loaded_hypernetwork = initial_hypernetwork + shared.sd_model.cond_stage_model.to(devices.device) + shared.sd_model.first_stage_model.to(devices.device) sd_hijack.apply_optimizations() diff --git a/modules/shared.py b/modules/shared.py index 20b45f23..c1092ff7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -228,6 +228,10 @@ options_templates.update(options_section(('system', "System"), { "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), })) +options_templates.update(options_section(('training', "Training"), { + "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP form VRAM when training"), +})) + options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True), "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 4d006366..f61f40d3 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -8,14 +8,14 @@ from torchvision import transforms import random import tqdm -from modules import devices +from modules import devices, shared import re re_tag = re.compile(r"[a-zA-Z][_\w\d()]+") class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False): self.placeholder_token = placeholder_token @@ -32,6 +32,8 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' + cond_model = shared.sd_model.cond_stage_model + self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): @@ -53,7 +55,13 @@ class PersonalizedBase(Dataset): init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() init_latent = init_latent.to(devices.cpu) - self.dataset.append((init_latent, filename_tokens)) + if include_cond: + text = self.create_text(filename_tokens) + cond = cond_model([text]).to(devices.cpu) + else: + cond = None + + self.dataset.append((init_latent, filename_tokens, cond)) self.length = len(self.dataset) * repeats @@ -64,6 +72,12 @@ class PersonalizedBase(Dataset): def shuffle(self): self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] + def create_text(self, filename_tokens): + text = random.choice(self.lines) + text = text.replace("[name]", self.placeholder_token) + text = text.replace("[filewords]", ' '.join(filename_tokens)) + return text + def __len__(self): return self.length @@ -72,10 +86,7 @@ class PersonalizedBase(Dataset): self.shuffle() index = self.indexes[i % len(self.indexes)] - x, filename_tokens = self.dataset[index] - - text = random.choice(self.lines) - text = text.replace("[name]", self.placeholder_token) - text = text.replace("[filewords]", ' '.join(filename_tokens)) + x, filename_tokens, cond = self.dataset[index] - return x, text + text = self.create_text(filename_tokens) + return x, text, cond diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index bb05cdc6..35f4bd9e 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -201,7 +201,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini return embedding, filename pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) - for i, (x, text) in pbar: + for i, (x, text, _) in pbar: embedding.step = i + ititial_step if embedding.step > steps: -- cgit v1.2.3 From f53f703aebc801c4204182d52bb1e0bef9808e1f Mon Sep 17 00:00:00 2001 From: JC_Array Date: Tue, 11 Oct 2022 18:12:12 -0500 Subject: resolved conflicts, moved settings under interrogate section, settings only show if deepbooru flag is enabled --- modules/deepbooru.py | 2 +- modules/shared.py | 19 +++++++++---------- modules/textual_inversion/preprocess.py | 2 +- modules/ui.py | 2 +- 4 files changed, 12 insertions(+), 13 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 89dcac3c..29529949 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -8,7 +8,7 @@ def get_deepbooru_tags(pil_image): This method is for running only one image at a time for simple use. Used to the img2img interrogate. """ from modules import shared # prevents circular reference - create_deepbooru_process(shared.opts.deepbooru_threshold, shared.opts.deepbooru_sort_alpha) + create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, shared.opts.deepbooru_sort_alpha) shared.deepbooru_process_return["value"] = -1 shared.deepbooru_process_queue.put(pil_image) while shared.deepbooru_process_return["value"] == -1: diff --git a/modules/shared.py b/modules/shared.py index 817203f8..5456c477 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -248,15 +248,20 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) -options_templates.update(options_section(('interrogate', "Interrogate Options"), { +interrogate_option_dictionary = { "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"), "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), - "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"), - "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), -})) + "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)") +} + +if cmd_opts.deepdanbooru: + interrogate_option_dictionary["interrogate_deepbooru_score_threshold"] = OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}) + interrogate_option_dictionary["deepbooru_sort_alpha"] = OptionInfo(True, "Interrogate: deepbooru sort alphabetically", gr.Checkbox) + +options_templates.update(options_section(('interrogate', "Interrogate Options"), interrogate_option_dictionary)) options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), @@ -282,12 +287,6 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) -if cmd_opts.deepdanbooru: - options_templates.update(options_section(('deepbooru-params', "DeepBooru parameters"), { - "deepbooru_sort_alpha": OptionInfo(True, "Sort Alphabetical", gr.Checkbox), - 'deepbooru_threshold': OptionInfo(0.5, "Threshold", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), - })) - class Options: data = None diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index a96388d6..113cecf1 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -29,7 +29,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ shared.interrogator.load() if process_caption_deepbooru: - deepbooru.create_deepbooru_process(opts.deepbooru_threshold, opts.deepbooru_sort_alpha) + deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, opts.deepbooru_sort_alpha) def save_pic_with_caption(image, index): if process_caption: diff --git a/modules/ui.py b/modules/ui.py index 2891fc8c..fa45edca 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -317,7 +317,7 @@ def interrogate(image): def interrogate_deepbooru(image): - prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold) + prompt = get_deepbooru_tags(image) return gr_show(True) if prompt is None else prompt -- cgit v1.2.3 From 65b973ac4e547a325f30a05f852b161421af2041 Mon Sep 17 00:00:00 2001 From: supersteve3d <39339941+supersteve3d@users.noreply.github.com> Date: Wed, 12 Oct 2022 08:21:52 +0800 Subject: Update shared.py Correct typo to "Unload VAE and CLIP from VRAM when training" in settings tab. --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index c1092ff7..46bc740c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -229,7 +229,7 @@ options_templates.update(options_section(('system', "System"), { })) options_templates.update(options_section(('training', "Training"), { - "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP form VRAM when training"), + "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"), })) options_templates.update(options_section(('sd', "Stable Diffusion"), { -- cgit v1.2.3 From d717eb079cd6b7fa7a4f97c0a10d400bdec753fb Mon Sep 17 00:00:00 2001 From: Greg Fuller Date: Tue, 11 Oct 2022 18:02:41 -0700 Subject: Interrogate: add option to include ranks in output Since the UI also allows users to specify ranks, it can be useful to show people what ranks are being returned by interrogate This can also give much better results when feeding the interrogate results back into either img2img or txt2img, especially when trying to generate a specific character or scene for which you have a similar concept image Testing Steps: Launch Webui with command line arg: --deepdanbooru Navigate to img2img tab, use interrogate DeepBooru, verify tags appears as before. Use "Interrogate CLIP", verify prompt appears as before Navigate to Settings tab, enable new option, click "apply settings" Navigate to img2img, Interrogate DeepBooru again, verify that weights appear and are properly formatted. Note that "Interrogate CLIP" prompt is still unchanged In my testing, this change has no effect to "Interrogate CLIP", as it seems to generate a sentence-structured caption, and not a set of tags. (reproduce changes from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/2149/commits/6ed4faac46c45ca7353f228aca9b436bbaba7bc7) --- modules/deepbooru.py | 14 +++++++++----- modules/interrogate.py | 7 +++++-- modules/shared.py | 1 + modules/ui.py | 5 ++--- 4 files changed, 17 insertions(+), 10 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 7e3c0618..32d741e2 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -3,7 +3,7 @@ from concurrent.futures import ProcessPoolExecutor from multiprocessing import get_context -def _load_tf_and_return_tags(pil_image, threshold): +def _load_tf_and_return_tags(pil_image, threshold, include_ranks): import deepdanbooru as dd import tensorflow as tf import numpy as np @@ -52,12 +52,16 @@ def _load_tf_and_return_tags(pil_image, threshold): if result_dict[tag] >= threshold: if tag.startswith("rating:"): continue - result_tags_out.append(tag) + tag_formatted = tag.replace('_', ' ').replace(':', ' ') + if include_ranks: + result_tags_out.append(f'({tag_formatted}:{result_dict[tag]})') + else: + result_tags_out.append(tag_formatted) result_tags_print.append(f'{result_dict[tag]} {tag}') print('\n'.join(sorted(result_tags_print, reverse=True))) - return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') + return ', '.join(result_tags_out) def subprocess_init_no_cuda(): @@ -65,9 +69,9 @@ def subprocess_init_no_cuda(): os.environ["CUDA_VISIBLE_DEVICES"] = "-1" -def get_deepbooru_tags(pil_image, threshold=0.5): +def get_deepbooru_tags(pil_image, threshold=0.5, include_ranks=False): context = get_context('spawn') with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor: - f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, ) + f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, include_ranks) ret = f.result() # will rethrow any exceptions return ret \ No newline at end of file diff --git a/modules/interrogate.py b/modules/interrogate.py index 635e266e..af858cc0 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -123,7 +123,7 @@ class InterrogateModels: return caption[0] - def interrogate(self, pil_image): + def interrogate(self, pil_image, include_ranks=False): res = None try: @@ -156,7 +156,10 @@ class InterrogateModels: for name, topn, items in self.categories: matches = self.rank(image_features, items, top_count=topn) for match, score in matches: - res += ", " + match + if include_ranks: + res += ", " + match + else: + res += f", ({match}:{score})" except Exception: print(f"Error interrogating", file=sys.stderr) diff --git a/modules/shared.py b/modules/shared.py index c1092ff7..3e0bfd72 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -251,6 +251,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('interrogate', "Interrogate Options"), { "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"), + "interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."), "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), diff --git a/modules/ui.py b/modules/ui.py index 1204eef7..f4dbe247 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -311,13 +311,12 @@ def apply_styles(prompt, prompt_neg, style1_name, style2_name): def interrogate(image): - prompt = shared.interrogator.interrogate(image) - + prompt = shared.interrogator.interrogate(image, include_ranks=opts.interrogate_return_ranks) return gr_show(True) if prompt is None else prompt def interrogate_deepbooru(image): - prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold) + prompt = get_deepbooru_tags(image, opts.interrogate_deepbooru_score_threshold, opts.interrogate_return_ranks) return gr_show(True) if prompt is None else prompt -- cgit v1.2.3 From 6ac2ec2b78bc5fabd09cb866dd9a71061d669269 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 12 Oct 2022 07:01:20 +0300 Subject: create dir for hypernetworks --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index c1092ff7..e65e77f8 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -86,6 +86,7 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram xformers_available = False config_filename = cmd_opts.ui_settings_file +os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None -- cgit v1.2.3 From 336bd8703c7b4d71f2f096f303599925a30b8167 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 12 Oct 2022 09:00:07 +0300 Subject: just add the deepdanbooru settings unconditionally --- modules/shared.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index f150e024..42e99741 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -249,20 +249,15 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), })) -interrogate_option_dictionary = { +options_templates.update(options_section(('interrogate', "Interrogate Options"), { "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"), "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), - "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)") -} - -if cmd_opts.deepdanbooru: - interrogate_option_dictionary["interrogate_deepbooru_score_threshold"] = OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}) - interrogate_option_dictionary["deepbooru_sort_alpha"] = OptionInfo(True, "Interrogate: deepbooru sort alphabetically", gr.Checkbox) - -options_templates.update(options_section(('interrogate', "Interrogate Options"), interrogate_option_dictionary)) + "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), + "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), +})) options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), -- cgit v1.2.3 From c3c8eef9fd5a0c8b26319e32ca4a19b56204e6df Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 12 Oct 2022 20:49:47 +0300 Subject: train: change filename processing to be more simple and configurable train: make it possible to make text files with prompts train: rework scheduler so that there's less repeating code in textual inversion and hypernets train: move epochs setting to options --- javascript/hints.js | 3 ++ modules/hypernetworks/hypernetwork.py | 40 +++++++++------------- modules/shared.py | 3 ++ modules/textual_inversion/dataset.py | 47 +++++++++++++++++++------- modules/textual_inversion/learn_schedule.py | 37 +++++++++++++++++++- modules/textual_inversion/textual_inversion.py | 35 +++++++------------ modules/ui.py | 2 -- 7 files changed, 105 insertions(+), 62 deletions(-) (limited to 'modules/shared.py') diff --git a/javascript/hints.js b/javascript/hints.js index b81c181b..d51ee14c 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -81,6 +81,9 @@ titles = { "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.", "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.", + + "Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.", + "Filename join string": "This string will be used to hoin split words into a single line if the option above is enabled.", } diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 8314450a..b6c06d49 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -14,7 +14,7 @@ import torch from torch import einsum from einops import rearrange, repeat import modules.textual_inversion.dataset -from modules.textual_inversion.learn_schedule import LearnSchedule +from modules.textual_inversion.learn_schedule import LearnRateScheduler class HypernetworkModule(torch.nn.Module): @@ -223,31 +223,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, if ititial_step > steps: return hypernetwork, filename - schedules = iter(LearnSchedule(learn_rate, steps, ititial_step)) - (learn_rate, end_step) = next(schedules) - print(f'Training at rate of {learn_rate} until step {end_step}') - - optimizer = torch.optim.AdamW(weights, lr=learn_rate) + scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) - for i, (x, text, cond) in pbar: + for i, entry in pbar: hypernetwork.step = i + ititial_step - if hypernetwork.step > end_step: - try: - (learn_rate, end_step) = next(schedules) - except Exception: - break - tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}') - for pg in optimizer.param_groups: - pg['lr'] = learn_rate + scheduler.apply(optimizer, hypernetwork.step) + if scheduler.finished: + break if shared.state.interrupted: break with torch.autocast("cuda"): - cond = cond.to(devices.device) - x = x.to(devices.device) + cond = entry.cond.to(devices.device) + x = entry.latent.to(devices.device) loss = shared.sd_model(x.unsqueeze(0), cond)[0] del x del cond @@ -267,7 +259,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') - preview_text = text if preview_image_prompt == "" else preview_image_prompt + preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt optimizer.zero_grad() shared.sd_model.cond_stage_model.to(devices.device) @@ -282,16 +274,16 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, ) processed = processing.process_images(p) - image = processed.images[0] + image = processed.images[0] if len(processed.images)>0 else None if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) - shared.state.current_image = image - image.save(last_saved_image) - - last_saved_image += f", prompt: {preview_text}" + if image is not None: + shared.state.current_image = image + image.save(last_saved_image) + last_saved_image += f", prompt: {preview_text}" shared.state.job_no = hypernetwork.step @@ -299,7 +291,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,

Loss: {losses.mean():.7f}
Step: {hypernetwork.step}
-Last prompt: {html.escape(text)}
+Last prompt: {html.escape(entry.cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

diff --git a/modules/shared.py b/modules/shared.py index 42e99741..e64e69fc 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -231,6 +231,9 @@ options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"), + "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), + "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), + "training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), })) options_templates.update(options_section(('sd', "Stable Diffusion"), { diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index f61f40d3..67e90afe 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -11,11 +11,21 @@ import tqdm from modules import devices, shared import re -re_tag = re.compile(r"[a-zA-Z][_\w\d()]+") +re_numbers_at_start = re.compile(r"^[-\d]+\s*") + + +class DatasetEntry: + def __init__(self, filename=None, latent=None, filename_text=None): + self.filename = filename + self.latent = latent + self.filename_text = filename_text + self.cond = None + self.cond_text = None class PersonalizedBase(Dataset): def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False): + re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None self.placeholder_token = placeholder_token @@ -42,9 +52,18 @@ class PersonalizedBase(Dataset): except Exception: continue + text_filename = os.path.splitext(path)[0] + ".txt" filename = os.path.basename(path) - filename_tokens = os.path.splitext(filename)[0] - filename_tokens = re_tag.findall(filename_tokens) + + if os.path.exists(text_filename): + with open(text_filename, "r", encoding="utf8") as file: + filename_text = file.read() + else: + filename_text = os.path.splitext(filename)[0] + filename_text = re.sub(re_numbers_at_start, '', filename_text) + if re_word: + tokens = re_word.findall(filename_text) + filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens) npimage = np.array(image).astype(np.uint8) npimage = (npimage / 127.5 - 1.0).astype(np.float32) @@ -55,13 +74,13 @@ class PersonalizedBase(Dataset): init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze() init_latent = init_latent.to(devices.cpu) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent) + if include_cond: - text = self.create_text(filename_tokens) - cond = cond_model([text]).to(devices.cpu) - else: - cond = None + entry.cond_text = self.create_text(filename_text) + entry.cond = cond_model([entry.cond_text]).to(devices.cpu) - self.dataset.append((init_latent, filename_tokens, cond)) + self.dataset.append(entry) self.length = len(self.dataset) * repeats @@ -72,10 +91,10 @@ class PersonalizedBase(Dataset): def shuffle(self): self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])] - def create_text(self, filename_tokens): + def create_text(self, filename_text): text = random.choice(self.lines) text = text.replace("[name]", self.placeholder_token) - text = text.replace("[filewords]", ' '.join(filename_tokens)) + text = text.replace("[filewords]", filename_text) return text def __len__(self): @@ -86,7 +105,9 @@ class PersonalizedBase(Dataset): self.shuffle() index = self.indexes[i % len(self.indexes)] - x, filename_tokens, cond = self.dataset[index] + entry = self.dataset[index] + + if entry.cond is None: + entry.cond_text = self.create_text(entry.filename_text) - text = self.create_text(filename_tokens) - return x, text, cond + return entry diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py index db720271..2062726a 100644 --- a/modules/textual_inversion/learn_schedule.py +++ b/modules/textual_inversion/learn_schedule.py @@ -1,6 +1,12 @@ +import tqdm -class LearnSchedule: + +class LearnScheduleIterator: def __init__(self, learn_rate, max_steps, cur_step=0): + """ + specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000 + """ + pairs = learn_rate.split(',') self.rates = [] self.it = 0 @@ -32,3 +38,32 @@ class LearnSchedule: return self.rates[self.it - 1] else: raise StopIteration + + +class LearnRateScheduler: + def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True): + self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step) + (self.learn_rate, self.end_step) = next(self.schedules) + self.verbose = verbose + + if self.verbose: + print(f'Training at rate of {self.learn_rate} until step {self.end_step}') + + self.finished = False + + def apply(self, optimizer, step_number): + if step_number <= self.end_step: + return + + try: + (self.learn_rate, self.end_step) = next(self.schedules) + except Exception: + self.finished = True + return + + if self.verbose: + tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}') + + for pg in optimizer.param_groups: + pg['lr'] = self.learn_rate + diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index c5153e4a..fa0e33a2 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -11,7 +11,7 @@ from PIL import Image, PngImagePlugin from modules import shared, devices, sd_hijack, processing, sd_models import modules.textual_inversion.dataset -from modules.textual_inversion.learn_schedule import LearnSchedule +from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, @@ -172,8 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn - -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt): +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -205,7 +204,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) hijack = sd_hijack.model_hijack @@ -221,32 +220,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if ititial_step > steps: return embedding, filename - schedules = iter(LearnSchedule(learn_rate, steps, ititial_step)) - (learn_rate, end_step) = next(schedules) - print(f'Training at rate of {learn_rate} until step {end_step}') - - optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate) + scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) + optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) - for i, (x, text, _) in pbar: + for i, entry in pbar: embedding.step = i + ititial_step - if embedding.step > end_step: - try: - (learn_rate, end_step) = next(schedules) - except: - break - tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}') - for pg in optimizer.param_groups: - pg['lr'] = learn_rate + scheduler.apply(optimizer, embedding.step) + if scheduler.finished: + break if shared.state.interrupted: break with torch.autocast("cuda"): - c = cond_model([text]) + c = cond_model([entry.cond_text]) - x = x.to(devices.device) + x = entry.latent.to(devices.device) loss = shared.sd_model(x.unsqueeze(0), c)[0] del x @@ -268,7 +259,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') - preview_text = text if preview_image_prompt == "" else preview_image_prompt + preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, @@ -314,7 +305,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini

Loss: {losses.mean():.7f}
Step: {embedding.step}
-Last prompt: {html.escape(text)}
+Last prompt: {html.escape(entry.cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

diff --git a/modules/ui.py b/modules/ui.py index 2b332267..c42535c8 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1098,7 +1098,6 @@ def create_ui(wrap_gradio_gpu_call): training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) steps = gr.Number(label='Max steps', value=100000, precision=0) - num_repeats = gr.Number(label='Number of repeats for a single input image per epoch', value=100, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) @@ -1176,7 +1175,6 @@ def create_ui(wrap_gradio_gpu_call): training_width, training_height, steps, - num_repeats, create_image_every, save_embedding_every, template_file, -- cgit v1.2.3 From 698d303b04e293635bfb49c525409f3bcf671dce Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 12 Oct 2022 21:55:43 +0300 Subject: deepbooru: added option to use spaces or underscores deepbooru: added option to quote (\) in tags deepbooru/BLIP: write caption to file instead of image filename deepbooru/BLIP: now possible to use both for captions deepbooru: process is stopped even if an exception occurs --- modules/deepbooru.py | 65 ++++++++++++++++++----- modules/shared.py | 2 + modules/textual_inversion/preprocess.py | 92 ++++++++++++++------------------- modules/ui.py | 7 +-- 4 files changed, 95 insertions(+), 71 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/deepbooru.py b/modules/deepbooru.py index 29529949..419e6a9c 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -2,33 +2,44 @@ import os.path from concurrent.futures import ProcessPoolExecutor import multiprocessing import time +import re + +re_special = re.compile(r'([\\()])') def get_deepbooru_tags(pil_image): """ This method is for running only one image at a time for simple use. Used to the img2img interrogate. """ from modules import shared # prevents circular reference - create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, shared.opts.deepbooru_sort_alpha) - shared.deepbooru_process_return["value"] = -1 - shared.deepbooru_process_queue.put(pil_image) - while shared.deepbooru_process_return["value"] == -1: - time.sleep(0.2) - tags = shared.deepbooru_process_return["value"] - release_process() - return tags + try: + create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts()) + return get_tags_from_process(pil_image) + finally: + release_process() + + +def create_deepbooru_opts(): + from modules import shared -def deepbooru_process(queue, deepbooru_process_return, threshold, alpha_sort): + return { + "use_spaces": shared.opts.deepbooru_use_spaces, + "use_escape": shared.opts.deepbooru_escape, + "alpha_sort": shared.opts.deepbooru_sort_alpha, + } + + +def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts): model, tags = get_deepbooru_tags_model() while True: # while process is running, keep monitoring queue for new image pil_image = queue.get() if pil_image == "QUIT": break else: - deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort) + deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts) -def create_deepbooru_process(threshold, alpha_sort): +def create_deepbooru_process(threshold, deepbooru_opts): """ Creates deepbooru process. A queue is created to send images into the process. This enables multiple images to be processed in a row without reloading the model or creating a new process. To return the data, a shared @@ -41,10 +52,23 @@ def create_deepbooru_process(threshold, alpha_sort): shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue() shared.deepbooru_process_return = shared.deepbooru_process_manager.dict() shared.deepbooru_process_return["value"] = -1 - shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, alpha_sort)) + shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts)) shared.deepbooru_process.start() +def get_tags_from_process(image): + from modules import shared + + shared.deepbooru_process_return["value"] = -1 + shared.deepbooru_process_queue.put(image) + while shared.deepbooru_process_return["value"] == -1: + time.sleep(0.2) + caption = shared.deepbooru_process_return["value"] + shared.deepbooru_process_return["value"] = -1 + + return caption + + def release_process(): """ Stops the deepbooru process to return used memory @@ -81,10 +105,15 @@ def get_deepbooru_tags_model(): return model, tags -def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort): +def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts): import deepdanbooru as dd import tensorflow as tf import numpy as np + + alpha_sort = deepbooru_opts['alpha_sort'] + use_spaces = deepbooru_opts['use_spaces'] + use_escape = deepbooru_opts['use_escape'] + width = model.input_shape[2] height = model.input_shape[1] image = np.array(pil_image) @@ -129,4 +158,12 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort) print('\n'.join(sorted(result_tags_print, reverse=True))) - return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') + tags_text = ', '.join(result_tags_out) + + if use_spaces: + tags_text = tags_text.replace('_', ' ') + + if use_escape: + tags_text = re.sub(re_special, r'\\\1', tags_text) + + return tags_text.replace(':', ' ') diff --git a/modules/shared.py b/modules/shared.py index e64e69fc..78b73aae 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -260,6 +260,8 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), + "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), + "deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"), })) options_templates.update(options_section(('ui', "User interface"), { diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 113cecf1..3047bede 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -10,7 +10,28 @@ from modules.shared import opts, cmd_opts if cmd_opts.deepdanbooru: import modules.deepbooru as deepbooru + def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): + try: + if process_caption: + shared.interrogator.load() + + if process_caption_deepbooru: + deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, deepbooru.create_deepbooru_opts()) + + preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru) + + finally: + + if process_caption: + shared.interrogator.send_blip_to_ram() + + if process_caption_deepbooru: + deepbooru.release_process() + + + +def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False): width = process_width height = process_height src = os.path.abspath(process_src) @@ -25,30 +46,28 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ shared.state.textinfo = "Preprocessing..." shared.state.job_count = len(files) - if process_caption: - shared.interrogator.load() - - if process_caption_deepbooru: - deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, opts.deepbooru_sort_alpha) - def save_pic_with_caption(image, index): + caption = "" + if process_caption: - caption = "-" + shared.interrogator.generate_caption(image) - caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png") - elif process_caption_deepbooru: - shared.deepbooru_process_return["value"] = -1 - shared.deepbooru_process_queue.put(image) - while shared.deepbooru_process_return["value"] == -1: - time.sleep(0.2) - caption = "-" + shared.deepbooru_process_return["value"] - caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png") - shared.deepbooru_process_return["value"] = -1 - else: - caption = filename - caption = os.path.splitext(caption)[0] - caption = os.path.basename(caption) + caption += shared.interrogator.generate_caption(image) + + if process_caption_deepbooru: + if len(caption) > 0: + caption += ", " + caption += deepbooru.get_tags_from_process(image) + + filename_part = filename + filename_part = os.path.splitext(filename_part)[0] + filename_part = os.path.basename(filename_part) + + basename = f"{index:05}-{subindex[0]}-{filename_part}" + image.save(os.path.join(dst, f"{basename}.png")) + + if len(caption) > 0: + with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file: + file.write(caption) - image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png")) subindex[0] += 1 def save_pic(image, index): @@ -93,34 +112,3 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ save_pic(img, index) shared.state.nextjob() - - if process_caption: - shared.interrogator.send_blip_to_ram() - - if process_caption_deepbooru: - deepbooru.release_process() - - -def sanitize_caption(base_path, original_caption, suffix): - operating_system = platform.system().lower() - if (operating_system == "windows"): - invalid_path_characters = "\\/:*?\"<>|" - max_path_length = 259 - else: - invalid_path_characters = "/" #linux/macos - max_path_length = 1023 - caption = original_caption - for invalid_character in invalid_path_characters: - caption = caption.replace(invalid_character, "") - fixed_path_length = len(base_path) + len(suffix) - if fixed_path_length + len(caption) <= max_path_length: - return caption - caption_tokens = caption.split() - new_caption = "" - for token in caption_tokens: - last_caption = new_caption - new_caption = new_caption + token + " " - if (len(new_caption) + fixed_path_length - 1 > max_path_length): - break - print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr) - return last_caption.strip() diff --git a/modules/ui.py b/modules/ui.py index c42535c8..e07ee0e1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1074,11 +1074,8 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row(): process_flip = gr.Checkbox(label='Create flipped copies') process_split = gr.Checkbox(label='Split oversized images into two') - process_caption = gr.Checkbox(label='Use BLIP caption as filename') - if cmd_opts.deepdanbooru: - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename') - else: - process_caption_deepbooru = gr.Checkbox(label='Use deepbooru caption as filename', visible=False) + process_caption = gr.Checkbox(label='Use BLIP for caption') + process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False) with gr.Row(): with gr.Column(scale=3): -- cgit v1.2.3 From 78592d404acba7db3baf8d78bdc19266906e684a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 13 Oct 2022 07:40:03 +0300 Subject: remove interrogate option I accidentally deleted --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 78b73aae..9bda45c1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -258,6 +258,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), + "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"), "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), -- cgit v1.2.3 From bb7baf6b9cb6b4b9fa09b6f07ef997db32fe6e58 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 13 Oct 2022 16:07:18 +0300 Subject: add option to change what's shown in quicksettings bar --- javascript/hints.js | 2 ++ modules/shared.py | 4 ++-- modules/ui.py | 16 +++++++++------- style.css | 1 + 4 files changed, 14 insertions(+), 9 deletions(-) (limited to 'modules/shared.py') diff --git a/javascript/hints.js b/javascript/hints.js index 32f10fde..06bbd9e2 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -84,6 +84,8 @@ titles = { "Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.", "Filename join string": "This string will be used to hoin split words into a single line if the option above is enabled.", + + "Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual stetting tab. See modules/shared.py for setting names. Requires restart to apply." } diff --git a/modules/shared.py b/modules/shared.py index 5f6101a4..4d3ed625 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -152,7 +152,6 @@ class OptionInfo: self.component_args = component_args self.onchange = onchange self.section = None - self.show_on_main_page = show_on_main_page def options_section(section_identifier, options_dict): @@ -237,7 +236,7 @@ options_templates.update(options_section(('training', "Training"), { })) options_templates.update(options_section(('sd', "Stable Diffusion"), { - "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True), + "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), @@ -250,6 +249,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "filter_nsfw": OptionInfo(False, "Filter NSFW content"), 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), + 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { diff --git a/modules/ui.py b/modules/ui.py index e07ee0e1..a0529860 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1305,6 +1305,9 @@ Requested path was: {f} settings_cols = 3 items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols) + quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")] + quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings') + quicksettings_list = [] cols_displayed = 0 @@ -1329,7 +1332,7 @@ Requested path was: {f} gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='

{}

'.format(item.section[1])) - if item.show_on_main_page: + if k in quicksettings_names: quicksettings_list.append((i, k, item)) components.append(dummy_component) else: @@ -1338,7 +1341,11 @@ Requested path was: {f} components.append(component) items_displayed += 1 - request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + with gr.Row(): + request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') + restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') + request_notifications.click( fn=lambda: None, inputs=[], @@ -1346,10 +1353,6 @@ Requested path was: {f} _js='function(){}' ) - with gr.Row(): - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') - restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') - def reload_scripts(): modules.scripts.reload_script_body_only() @@ -1364,7 +1367,6 @@ Requested path was: {f} shared.state.interrupt() settings_interface.gradio_ref.do_restart = True - restart_gradio.click( fn=request_restart, inputs=[], diff --git a/style.css b/style.css index e6fa10b4..55c41971 100644 --- a/style.css +++ b/style.css @@ -488,6 +488,7 @@ input[type="range"]{ #quicksettings > div > div{ max-width: 32em; padding: 0; + margin-right: 0.75em; } canvas[key="mask"] { -- cgit v1.2.3 From a10b0e11fc22cc67b6a3664f2ddd17425d8433a8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 13 Oct 2022 19:22:41 +0300 Subject: options to refresh list of models and hypernetworks --- modules/shared.py | 9 +++++---- modules/ui.py | 33 +++++++++++++++++++++++++++++---- style.css | 21 ++++++++++++++++++++- 3 files changed, 54 insertions(+), 9 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 4d3ed625..d8e3a286 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers +from modules import sd_samplers, sd_models from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path @@ -145,13 +145,14 @@ def realesrgan_models_names(): class OptionInfo: - def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False): + def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False, refresh=None): self.default = default self.label = label self.component = component self.component_args = component_args self.onchange = onchange self.section = None + self.refresh = refresh def options_section(section_identifier, options_dict): @@ -236,8 +237,8 @@ options_templates.update(options_section(('training', "Training"), { })) options_templates.update(options_section(('sd', "Stable Diffusion"), { - "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}), - "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), + "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), + "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), diff --git a/modules/ui.py b/modules/ui.py index a0529860..0a58f6be 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -78,6 +78,8 @@ reuse_symbol = '\u267b\ufe0f' # ♻️ art_symbol = '\U0001f3a8' # 🎨 paste_symbol = '\u2199\ufe0f' # ↙ folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 + def plaintext_to_html(text): text = "

" + "
\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "

" @@ -1210,8 +1212,7 @@ def create_ui(wrap_gradio_gpu_call): outputs=[], ) - - def create_setting_component(key): + def create_setting_component(key, is_quicksettings=False): def fun(): return opts.data[key] if key in opts.data else opts.data_labels[key].default @@ -1231,7 +1232,31 @@ def create_ui(wrap_gradio_gpu_call): else: raise Exception(f'bad options item type: {str(t)} for key {key}') - return comp(label=info.label, value=fun, **(args or {})) + if info.refresh is not None: + if is_quicksettings: + res = comp(label=info.label, value=fun, **(args or {})) + refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_"+key) + else: + with gr.Row(variant="compact"): + res = comp(label=info.label, value=fun, **(args or {})) + refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_" + key) + + def refresh(): + info.refresh() + refreshed_args = info.component_args() if callable(info.component_args) else info.component_args + res.choices = refreshed_args["choices"] + return gr.update(**(refreshed_args or {})) + + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[res], + ) + else: + res = comp(label=info.label, value=fun, **(args or {})) + + + return res components = [] component_dict = {} @@ -1401,7 +1426,7 @@ Requested path was: {f} with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Row(elem_id="quicksettings"): for i, k, item in quicksettings_list: - component = create_setting_component(k) + component = create_setting_component(k, is_quicksettings=True) component_dict[k] = component settings_interface.gradio_ref = demo diff --git a/style.css b/style.css index 55c41971..ad2a52cc 100644 --- a/style.css +++ b/style.css @@ -228,6 +228,8 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s border-top: 1px solid #eee; border-left: 1px solid #eee; border-right: 1px solid #eee; + + z-index: 300; } .dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{ @@ -480,17 +482,30 @@ input[type="range"]{ background: #a55000; } +#quicksettings { + gap: 0.4em; +} + #quicksettings > div{ border: none; background: none; + flex: unset; + gap: 0.5em; } #quicksettings > div > div{ max-width: 32em; + min-width: 24em; padding: 0; - margin-right: 0.75em; } +#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork{ + max-width: 2.5em; + min-width: 2.5em; + height: 2.4em; +} + + canvas[key="mask"] { z-index: 12 !important; filter: invert(); @@ -507,3 +522,7 @@ canvas[key="mask"] { z-index: 200; width: 8em; } + +.row.gr-compact{ + overflow: visible; +} -- cgit v1.2.3 From 354ef0da3b1f0fa5c113d04b6c79e3908c848d23 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 13 Oct 2022 20:12:37 +0300 Subject: add hypernetwork multipliers --- modules/hypernetworks/hypernetwork.py | 8 +++++++- modules/shared.py | 5 ++++- modules/ui.py | 5 ++++- scripts/xy_grid.py | 9 ++++++++- style.css | 3 +++ webui.py | 2 +- 6 files changed, 27 insertions(+), 5 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b6c06d49..f1248bb7 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -18,6 +18,8 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler class HypernetworkModule(torch.nn.Module): + multiplier = 1.0 + def __init__(self, dim, state_dict=None): super().__init__() @@ -36,7 +38,11 @@ class HypernetworkModule(torch.nn.Module): self.to(devices.device) def forward(self, x): - return x + (self.linear2(self.linear1(x))) + return x + (self.linear2(self.linear1(x))) * self.multiplier + + +def apply_strength(value=None): + HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength class Hypernetwork: diff --git a/modules/shared.py b/modules/shared.py index d8e3a286..5901e605 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -238,7 +238,8 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), - "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), + "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), + "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), @@ -348,6 +349,8 @@ class Options: item = self.data_labels.get(key) item.onchange = func + func() + def dumpjson(self): d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()} return json.dumps(d) diff --git a/modules/ui.py b/modules/ui.py index 0a58f6be..673014f2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1244,7 +1244,10 @@ def create_ui(wrap_gradio_gpu_call): def refresh(): info.refresh() refreshed_args = info.component_args() if callable(info.component_args) else info.component_args - res.choices = refreshed_args["choices"] + + for k, v in refreshed_args.items(): + setattr(res, k, v) + return gr.update(**(refreshed_args or {})) refresh_button.click( diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 02931ae6..efb63af5 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -107,6 +107,10 @@ def apply_hypernetwork(p, x, xs): hypernetwork.load_hypernetwork(name) +def apply_hypernetwork_strength(p, x, xs): + hypernetwork.apply_strength(x) + + def confirm_hypernetworks(p, xs): for x in xs: if x.lower() in ["", "none"]: @@ -165,6 +169,7 @@ axis_options = [ AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers), AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints), AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks), + AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None), AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None), AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None), AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None), @@ -250,7 +255,7 @@ class Script(scripts.Script): y_values = gr.Textbox(label="Y values", visible=False, lines=1) draw_legend = gr.Checkbox(label='Draw legend', value=True) - include_lone_images = gr.Checkbox(label='Include Separate Images', value=True) + include_lone_images = gr.Checkbox(label='Include Separate Images', value=False) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False) return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] @@ -377,6 +382,8 @@ class Script(scripts.Script): modules.sd_models.reload_model_weights(shared.sd_model) hypernetwork.load_hypernetwork(opts.sd_hypernetwork) + hypernetwork.apply_strength() + opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers diff --git a/style.css b/style.css index ad2a52cc..aa3d379c 100644 --- a/style.css +++ b/style.css @@ -522,6 +522,9 @@ canvas[key="mask"] { z-index: 200; width: 8em; } +#quicksettings .gr-box > div > div > input.gr-text-input { + top: -1.12em; +} .row.gr-compact{ overflow: visible; diff --git a/webui.py b/webui.py index 33ba7905..fe0ce321 100644 --- a/webui.py +++ b/webui.py @@ -72,7 +72,6 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) - def initialize(): modelloader.cleanup_models() modules.sd_models.setup_model() @@ -86,6 +85,7 @@ def initialize(): shared.sd_model = modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) + shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) def webui(): -- cgit v1.2.3 From bb57f30c2de46cfca5419ad01738a41705f96cc3 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Fri, 14 Oct 2022 10:56:41 +0200 Subject: init --- README.md | 1 + aesthetic_embeddings/insert_embs_here.txt | 0 modules/processing.py | 17 +++++- modules/sd_hijack.py | 80 +++++++++++++++++++++++++- modules/shared.py | 5 ++ modules/textual_inversion/dataset.py | 2 +- modules/textual_inversion/textual_inversion.py | 35 +++++++---- modules/txt2img.py | 11 +++- modules/ui.py | 59 ++++++++++++------- 9 files changed, 172 insertions(+), 38 deletions(-) create mode 100644 aesthetic_embeddings/insert_embs_here.txt (limited to 'modules/shared.py') diff --git a/README.md b/README.md index 859a91b6..7b8d018b 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web - No token limit for prompts (original stable diffusion lets you use up to 75 tokens) - DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args) - [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args) +- Aesthetic, a way to generate images with a specific aesthetic by using clip images embds (implementation of https://github.com/vicgalle/stable-diffusion-aesthetic-gradients) ## Installation and Running Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. diff --git a/aesthetic_embeddings/insert_embs_here.txt b/aesthetic_embeddings/insert_embs_here.txt new file mode 100644 index 00000000..e69de29b diff --git a/modules/processing.py b/modules/processing.py index d5172f00..9a033759 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -316,11 +316,16 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() -def process_images(p: StableDiffusionProcessing) -> Processed: +def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, + aesthetic_imgs=None,aesthetic_slerp=False) -> Processed: """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" + aesthetic_lr = float(aesthetic_lr) + aesthetic_weight = float(aesthetic_weight) + aesthetic_steps = int(aesthetic_steps) + if type(p.prompt) == list: - assert(len(p.prompt) > 0) + assert (len(p.prompt) > 0) else: assert p.prompt is not None @@ -394,7 +399,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed: #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) #c = p.sd_model.get_learned_conditioning(prompts) with devices.autocast(): - uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) + if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): + shared.sd_model.cond_stage_model.set_aesthetic_params(0, 0, 0) + uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], + p.steps) + if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): + shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight, + aesthetic_steps, aesthetic_imgs,aesthetic_slerp) c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) if len(model_hijack.comments) > 0: diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index c81722a0..6d5196fe 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -9,11 +9,14 @@ from torch.nn.functional import silu import modules.textual_inversion.textual_inversion from modules import prompt_parser, devices, sd_hijack_optimizations, shared -from modules.shared import opts, device, cmd_opts +from modules.shared import opts, device, cmd_opts, aesthetic_embeddings from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention import ldm.modules.diffusionmodules.model +from transformers import CLIPVisionModel, CLIPModel +import torch.optim as optim +import copy attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity @@ -109,13 +112,29 @@ class StableDiffusionModelHijack: _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) +def slerp(low, high, val): + low_norm = low/torch.norm(low, dim=1, keepdim=True) + high_norm = high/torch.norm(high, dim=1, keepdim=True) + omega = torch.acos((low_norm*high_norm).sum(1)) + so = torch.sin(omega) + res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + return res class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): super().__init__() self.wrapped = wrapped + self.clipModel = CLIPModel.from_pretrained( + self.wrapped.transformer.name_or_path + ) + del self.clipModel.vision_model self.hijack: StableDiffusionModelHijack = hijack self.tokenizer = wrapped.tokenizer + # self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval() + self.image_embs_name = None + self.image_embs = None + self.load_image_embs(None) + self.token_mults = {} self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] @@ -136,6 +155,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult + def set_aesthetic_params(self, aesthetic_lr, aesthetic_weight, aesthetic_steps, image_embs_name=None, + aesthetic_slerp=True): + self.slerp = aesthetic_slerp + self.aesthetic_lr = aesthetic_lr + self.aesthetic_weight = aesthetic_weight + self.aesthetic_steps = aesthetic_steps + self.load_image_embs(image_embs_name) + + def load_image_embs(self, image_embs_name): + if image_embs_name is None or len(image_embs_name) == 0: + image_embs_name = None + if image_embs_name is not None and self.image_embs_name != image_embs_name: + self.image_embs_name = image_embs_name + self.image_embs = torch.load(aesthetic_embeddings[self.image_embs_name], map_location=device) + self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True) + self.image_embs.requires_grad_(False) + def tokenize_line(self, line, used_custom_terms, hijack_comments): id_end = self.wrapped.tokenizer.eos_token_id @@ -333,7 +369,47 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z1 = self.process_tokens(tokens, multipliers) z = z1 if z is None else torch.cat((z, z1), axis=-2) - + + if len(text[ + 0]) != 0 and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None: + if not opts.use_old_emphasis_implementation: + remade_batch_tokens = [ + [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in + remade_batch_tokens] + + tokens = torch.asarray(remade_batch_tokens).to(device) + with torch.enable_grad(): + model = copy.deepcopy(self.clipModel).to(device) + model.requires_grad_(True) + + # We optimize the model to maximize the similarity + optimizer = optim.Adam( + model.text_model.parameters(), lr=self.aesthetic_lr + ) + + for i in range(self.aesthetic_steps): + text_embs = model.get_text_features(input_ids=tokens) + text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True) + sim = text_embs @ self.image_embs.T + loss = -sim + optimizer.zero_grad() + loss.mean().backward() + optimizer.step() + + zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) + if opts.CLIP_stop_at_last_layers > 1: + zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers] + zn = model.text_model.final_layer_norm(zn) + else: + zn = zn.last_hidden_state + model.cpu() + del model + + if self.slerp: + z = slerp(z, zn, self.aesthetic_weight) + else: + z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight + remade_batch_tokens = rem_tokens batch_multipliers = rem_multipliers i += 1 diff --git a/modules/shared.py b/modules/shared.py index 5901e605..cf13a10d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -30,6 +30,8 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") +parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(script_path, 'aesthetic_embeddings'), + help="aesthetic_embeddings directory(default: aesthetic_embeddings)") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") @@ -90,6 +92,9 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None +aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in + os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} + def reload_hypernetworks(): global hypernetworks diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 67e90afe..59b2b021 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -48,7 +48,7 @@ class PersonalizedBase(Dataset): print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): try: - image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) + image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.Resampling.BICUBIC) except Exception: continue diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index fa0e33a2..b12a8e6d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -172,7 +172,15 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn -def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt): +def batched(dataset, total, n=1): + for ndx in range(0, total, n): + yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))] + + +def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, + create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, + preview_image_prompt, batch_size=1, + gradient_accumulation=1): assert embedding_name, 'embedding not selected' shared.state.textinfo = "Initializing textual inversion training..." @@ -204,7 +212,11 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, + height=training_height, + repeats=shared.opts.training_image_repeats_per_epoch, + placeholder_token=embedding_name, model=shared.sd_model, + device=devices.device, template_file=template_file) hijack = sd_hijack.model_hijack @@ -223,7 +235,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) - pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) + pbar = tqdm.tqdm(enumerate(batched(ds, steps - ititial_step, batch_size)), total=steps - ititial_step) for i, entry in pbar: embedding.step = i + ititial_step @@ -235,17 +247,20 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini break with torch.autocast("cuda"): - c = cond_model([entry.cond_text]) + c = cond_model([e.cond_text for e in entry]) + + x = torch.stack([e.latent for e in entry]).to(devices.device) + loss = shared.sd_model(x, c)[0] - x = entry.latent.to(devices.device) - loss = shared.sd_model(x.unsqueeze(0), c)[0] del x losses[embedding.step % losses.shape[0]] = loss.item() - optimizer.zero_grad() loss.backward() - optimizer.step() + if ((i + 1) % gradient_accumulation == 0) or (i + 1 == steps - ititial_step): + optimizer.step() + optimizer.zero_grad() + epoch_num = embedding.step // len(ds) epoch_step = embedding.step - (epoch_num * len(ds)) + 1 @@ -259,7 +274,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') - preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt + preview_text = entry[0].cond_text if preview_image_prompt == "" else preview_image_prompt p = processing.StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, @@ -305,7 +320,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini

Loss: {losses.mean():.7f}
Step: {embedding.step}
-Last prompt: {html.escape(entry.cond_text)}
+Last prompt: {html.escape(entry[-1].cond_text)}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

diff --git a/modules/txt2img.py b/modules/txt2img.py index e985242b..78342024 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -6,7 +6,14 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, + restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, + subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, + height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, + aesthetic_lr=0, + aesthetic_weight=0, aesthetic_steps=0, + aesthetic_imgs=None, + aesthetic_slerp=False, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -40,7 +47,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: processed = modules.scripts.scripts_txt2img.run(p, *args) if processed is None: - processed = process_images(p) + processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp) shared.total_tqdm.clear() diff --git a/modules/ui.py b/modules/ui.py index 220fb80b..d961d126 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -24,7 +24,8 @@ import gradio.routes from modules import sd_hijack from modules.paths import script_path -from modules.shared import opts, cmd_opts +from modules.shared import opts, cmd_opts,aesthetic_embeddings + if cmd_opts.deepdanbooru: from modules.deepbooru import get_deepbooru_tags import modules.shared as shared @@ -534,6 +535,14 @@ def create_ui(wrap_gradio_gpu_call): width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + with gr.Group(): + aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") + aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.7) + aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=50) + + aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None) + aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) + with gr.Row(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) tiling = gr.Checkbox(label='Tiling', value=False) @@ -586,25 +595,30 @@ def create_ui(wrap_gradio_gpu_call): fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), _js="submit", inputs=[ - txt2img_prompt, - txt2img_negative_prompt, - txt2img_prompt_style, - txt2img_prompt_style2, - steps, - sampler_index, - restore_faces, - tiling, - batch_count, - batch_size, - cfg_scale, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, - height, - width, - enable_hr, - scale_latent, - denoising_strength, - ] + custom_inputs, + txt2img_prompt, + txt2img_negative_prompt, + txt2img_prompt_style, + txt2img_prompt_style2, + steps, + sampler_index, + restore_faces, + tiling, + batch_count, + batch_size, + cfg_scale, + seed, + subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, + height, + width, + enable_hr, + scale_latent, + denoising_strength, + aesthetic_lr, + aesthetic_weight, + aesthetic_steps, + aesthetic_imgs, + aesthetic_slerp + ] + custom_inputs, outputs=[ txt2img_gallery, generation_info, @@ -1097,6 +1111,9 @@ def create_ui(wrap_gradio_gpu_call): template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt")) training_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + batch_size = gr.Slider(minimum=1, maximum=64, step=1, label="Batch Size", value=4) + gradient_accumulation = gr.Slider(minimum=1, maximum=256, step=1, label="Gradient accumulation", + value=1) steps = gr.Number(label='Max steps', value=100000, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) @@ -1180,6 +1197,8 @@ def create_ui(wrap_gradio_gpu_call): template_file, save_image_with_stored_embedding, preview_image_prompt, + batch_size, + gradient_accumulation ], outputs=[ ti_output, -- cgit v1.2.3 From fdef8253a43ca5135923092ca9b85e878d980869 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 14 Oct 2022 04:42:53 -0400 Subject: Add 'interrogate' and 'all' choices to --use-cpu * Add 'interrogate' and 'all' choices to --use-cpu * Change type for --use-cpu argument to str.lower, so that choices are case insensitive --- modules/devices.py | 2 +- modules/interrogate.py | 14 +++++++------- modules/shared.py | 6 +++--- 3 files changed, 11 insertions(+), 11 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/devices.py b/modules/devices.py index 03ef58f1..eb422583 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -34,7 +34,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") -device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() +device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() dtype = torch.float16 dtype_vae = torch.float16 diff --git a/modules/interrogate.py b/modules/interrogate.py index af858cc0..9263d65a 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -55,7 +55,7 @@ class InterrogateModels: model, preprocess = clip.load(clip_model_name) model.eval() - model = model.to(shared.device) + model = model.to(devices.device_interrogate) return model, preprocess @@ -65,14 +65,14 @@ class InterrogateModels: if not shared.cmd_opts.no_half: self.blip_model = self.blip_model.half() - self.blip_model = self.blip_model.to(shared.device) + self.blip_model = self.blip_model.to(devices.device_interrogate) if self.clip_model is None: self.clip_model, self.clip_preprocess = self.load_clip_model() if not shared.cmd_opts.no_half: self.clip_model = self.clip_model.half() - self.clip_model = self.clip_model.to(shared.device) + self.clip_model = self.clip_model.to(devices.device_interrogate) self.dtype = next(self.clip_model.parameters()).dtype @@ -99,11 +99,11 @@ class InterrogateModels: text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)] top_count = min(top_count, len(text_array)) - text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(shared.device) + text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate) text_features = self.clip_model.encode_text(text_tokens).type(self.dtype) text_features /= text_features.norm(dim=-1, keepdim=True) - similarity = torch.zeros((1, len(text_array))).to(shared.device) + similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate) for i in range(image_features.shape[0]): similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1) similarity /= image_features.shape[0] @@ -116,7 +116,7 @@ class InterrogateModels: transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) - ])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device) + ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate) with torch.no_grad(): caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length) @@ -140,7 +140,7 @@ class InterrogateModels: res = caption - clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device) + clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate) precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext with torch.no_grad(), precision_scope("cuda"): diff --git a/modules/shared.py b/modules/shared.py index 5901e605..b6a5c1a8 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -54,7 +54,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") -parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[]) +parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower) parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) @@ -76,8 +76,8 @@ parser.add_argument("--disable-safe-unpickle", action='store_true', help="disabl cmd_opts = parser.parse_args() -devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ -(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer']) +devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ +(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer']) device = devices.device -- cgit v1.2.3 From 43f926aad1b77a4bb642c1d173adfae1f56cf42d Mon Sep 17 00:00:00 2001 From: Gugubo <29143981+Gugubo@users.noreply.github.com> Date: Fri, 14 Oct 2022 17:06:51 +0200 Subject: Add option to prevent empty spots in grid (1/2) --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index b6a5c1a8..159f504f 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -175,6 +175,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "grid_format": OptionInfo('png', 'File format for grids'), "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"), "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"), + "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"), "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}), "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), -- cgit v1.2.3 From a8eeb2b7ad0c43ad60ac2ba8bd299b9cb265fdd3 Mon Sep 17 00:00:00 2001 From: Ljzd-PRO <63289359+Ljzd-PRO@users.noreply.github.com> Date: Thu, 13 Oct 2022 02:03:08 +0800 Subject: add `--lowram` parameter load models to VRM instead of RAM (for machines which have bigger VRM than RAM such as free Google Colab server) --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 159f504f..cd4a4714 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -34,6 +34,7 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_ parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") +parser.add_argument("--lowram", action='store_true', help="load models to VRM instead of RAM (for machines which have bigger VRM than RAM such as free Google Colab server)") parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") -- cgit v1.2.3 From bb295f54785ac36dc6aa6f7103a3431464440fc3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 14 Oct 2022 20:03:41 +0300 Subject: rework the code for lowram a bit --- modules/sd_models.py | 12 ++---------- modules/shared.py | 3 ++- 2 files changed, 4 insertions(+), 11 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 78a198b9..3a01c93d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -134,11 +134,7 @@ def load_model_weights(model, checkpoint_info): print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") - if shared.cmd_opts.lowram: - print("Load to VRAM if GPU is available (low RAM)") - pl_sd = torch.load(checkpoint_file) - else: - pl_sd = torch.load(checkpoint_file, map_location="cpu") + pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") @@ -164,11 +160,7 @@ def load_model_weights(model, checkpoint_info): if os.path.exists(vae_file): print(f"Loading VAE weights from: {vae_file}") - if shared.cmd_opts.lowram: - print("Load to VRAM if GPU is available (low RAM)") - vae_ckpt = torch.load(vae_file) - else: - vae_ckpt = torch.load(vae_file, map_location="cpu") + vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} diff --git a/modules/shared.py b/modules/shared.py index cd4a4714..695d29b6 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -34,7 +34,7 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_ parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") -parser.add_argument("--lowram", action='store_true', help="load models to VRM instead of RAM (for machines which have bigger VRM than RAM such as free Google Colab server)") +parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM") parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") @@ -81,6 +81,7 @@ devices.device, devices.device_interrogate, devices.device_gfpgan, devices.devic (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer']) device = devices.device +weight_load_location = None if cmd_opts.lowram else "cpu" batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram -- cgit v1.2.3 From 03d62538aebeff51713619fe808c953bdb70193d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 14 Oct 2022 22:43:55 +0300 Subject: remove duplicate code for log loss, add step, make it read from options rather than gradio input --- modules/hypernetworks/hypernetwork.py | 20 ++++-------- modules/shared.py | 3 +- modules/textual_inversion/textual_inversion.py | 44 ++++++++++++++++++-------- modules/ui.py | 3 -- 4 files changed, 38 insertions(+), 32 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index edb8cba1..59c7ac6e 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -15,6 +15,7 @@ import torch from torch import einsum from einops import rearrange, repeat import modules.textual_inversion.dataset +from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler @@ -210,7 +211,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True) if unload: shared.sd_model.cond_stage_model.to(devices.cpu) @@ -263,19 +264,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') hypernetwork.save(last_saved_file) - if write_csv_every > 0 and hypernetwork_dir is not None and hypernetwork.step % write_csv_every == 0: - write_csv_header = False if os.path.exists(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv")) else True - - with open(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv"), "a+") as fout: - - csv_writer = csv.DictWriter(fout, fieldnames=["step", "loss", "learn_rate"]) - - if write_csv_header: - csv_writer.writeheader() - - csv_writer.writerow({"step": hypernetwork.step, - "loss": f"{losses.mean():.7f}", - "learn_rate": scheduler.learn_rate}) + textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { + "loss": f"{losses.mean():.7f}", + "learn_rate": scheduler.learn_rate + }) if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') diff --git a/modules/shared.py b/modules/shared.py index 695d29b6..d41a7ab3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -236,7 +236,8 @@ options_templates.update(options_section(('training', "Training"), { "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), - "training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), + "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), + "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"), })) options_templates.update(options_section(('sd', "Stable Diffusion"), { diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 1f5ace6f..da0d77a0 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -173,6 +173,32 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn +def write_loss(log_directory, filename, step, epoch_len, values): + if shared.opts.training_write_csv_every == 0: + return + + if step % shared.opts.training_write_csv_every != 0: + return + + write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True + + with open(os.path.join(log_directory, filename), "a+", newline='') as fout: + csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())]) + + if write_csv_header: + csv_writer.writeheader() + + epoch = step // epoch_len + epoch_step = step - epoch * epoch_len + + csv_writer.writerow({ + "step": step + 1, + "epoch": epoch + 1, + "epoch_step": epoch_step + 1, + **values, + }) + + def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): assert embedding_name, 'embedding not selected' @@ -257,20 +283,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') embedding.save(last_saved_file) - if write_csv_every > 0 and log_directory is not None and embedding.step % write_csv_every == 0: - write_csv_header = False if os.path.exists(os.path.join(log_directory, "textual_inversion_loss.csv")) else True - - with open(os.path.join(log_directory, "textual_inversion_loss.csv"), "a+") as fout: - - csv_writer = csv.DictWriter(fout, fieldnames=["epoch", "epoch_step", "loss", "learn_rate"]) - - if write_csv_header: - csv_writer.writeheader() - - csv_writer.writerow({"epoch": epoch_num + 1, - "epoch_step": epoch_step - 1, - "loss": f"{losses.mean():.7f}", - "learn_rate": scheduler.learn_rate}) + write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { + "loss": f"{losses.mean():.7f}", + "learn_rate": scheduler.learn_rate + }) if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') diff --git a/modules/ui.py b/modules/ui.py index be4a43a7..a08ffc9b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1172,7 +1172,6 @@ def create_ui(wrap_gradio_gpu_call): training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) steps = gr.Number(label='Max steps', value=100000, precision=0) create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0) - write_csv_every = gr.Number(label='Save an csv containing the loss to log directory every N steps, 0 to disable', value=500, precision=0) save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0) save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True) preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False) @@ -1251,7 +1250,6 @@ def create_ui(wrap_gradio_gpu_call): steps, create_image_every, save_embedding_every, - write_csv_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, @@ -1274,7 +1272,6 @@ def create_ui(wrap_gradio_gpu_call): steps, create_image_every, save_embedding_every, - write_csv_every, template_file, preview_from_txt2img, *txt2img_preview_params, -- cgit v1.2.3 From e21f01f64504bc651da6e85216474bbd35ee010d Mon Sep 17 00:00:00 2001 From: Rae Fu Date: Thu, 13 Oct 2022 23:00:38 -0600 Subject: add checkpoint cache option to UI for faster model switching switching time reduced from ~1500ms to ~280ms --- modules/sd_models.py | 54 +++++++++++++++++++++++++++++++--------------------- modules/shared.py | 1 + 2 files changed, 33 insertions(+), 22 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 0a55b4c3..f3660d8d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,4 +1,4 @@ -import glob +import collections import os.path import sys from collections import namedtuple @@ -15,6 +15,7 @@ model_path = os.path.abspath(os.path.join(models_path, model_dir)) CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config']) checkpoints_list = {} +checkpoints_loaded = collections.OrderedDict() try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. @@ -132,38 +133,46 @@ def load_model_weights(model, checkpoint_info): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash - print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") + if checkpoint_info not in checkpoints_loaded: + print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") - pl_sd = torch.load(checkpoint_file, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") + pl_sd = torch.load(checkpoint_file, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") - sd = get_state_dict_from_checkpoint(pl_sd) + sd = get_state_dict_from_checkpoint(pl_sd) + model.load_state_dict(sd, strict=False) - model.load_state_dict(sd, strict=False) + if shared.cmd_opts.opt_channelslast: + model.to(memory_format=torch.channels_last) - if shared.cmd_opts.opt_channelslast: - model.to(memory_format=torch.channels_last) + if not shared.cmd_opts.no_half: + model.half() - if not shared.cmd_opts.no_half: - model.half() + devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 - devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 + vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" - vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt" + if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None: + vae_file = shared.cmd_opts.vae_path - if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None: - vae_file = shared.cmd_opts.vae_path + if os.path.exists(vae_file): + print(f"Loading VAE weights from: {vae_file}") + vae_ckpt = torch.load(vae_file, map_location="cpu") + vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} - if os.path.exists(vae_file): - print(f"Loading VAE weights from: {vae_file}") - vae_ckpt = torch.load(vae_file, map_location="cpu") - vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} + model.first_stage_model.load_state_dict(vae_dict) - model.first_stage_model.load_state_dict(vae_dict) + model.first_stage_model.to(devices.dtype_vae) - model.first_stage_model.to(devices.dtype_vae) + checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) # LRU + else: + print(f"Loading weights [{sd_model_hash}] from cache") + checkpoints_loaded.move_to_end(checkpoint_info) + model.load_state_dict(checkpoints_loaded[checkpoint_info]) model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file @@ -202,6 +211,7 @@ def reload_model_weights(sd_model, info=None): return if sd_model.sd_checkpoint_info.config != checkpoint_info.config: + checkpoints_loaded.clear() shared.sd_model = load_model() return shared.sd_model diff --git a/modules/shared.py b/modules/shared.py index 5901e605..b2090da1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -238,6 +238,7 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), + "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), -- cgit v1.2.3 From f7ca63937ac83d32483285c3af09afaa356d6276 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 15 Oct 2022 13:23:12 +0300 Subject: bring back scale latent option in settings --- modules/processing.py | 8 ++++---- modules/shared.py | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/processing.py b/modules/processing.py index 7e2a416d..b9a1660e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -557,11 +557,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] - decoded_samples = decode_first_stage(self.sd_model, samples) + if opts.use_scale_latent_for_hires_fix: + samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") - if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None": - decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear") else: + decoded_samples = decode_first_stage(self.sd_model, samples) lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) batch_images = [] @@ -578,7 +578,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): decoded_samples = decoded_samples.to(shared.device) decoded_samples = 2. * decoded_samples - 1. - samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) + samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) shared.state.nextjob() diff --git a/modules/shared.py b/modules/shared.py index aa69bedf..b4141e67 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -218,6 +218,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), + "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space iamge when doing hires. fix"), })) options_templates.update(options_section(('face-restoration', "Face restoration"), { @@ -256,7 +257,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "filter_nsfw": OptionInfo(False, "Filter NSFW content"), 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), - 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { @@ -284,6 +284,7 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), + 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { -- cgit v1.2.3 From eef3bc649069d6caaef1274f132de28e528bfa7d Mon Sep 17 00:00:00 2001 From: NO_ob <15161159+NO-ob@users.noreply.github.com> Date: Sat, 15 Oct 2022 11:43:30 +0100 Subject: typo --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index b4141e67..fa30bbb0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -218,7 +218,7 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), - "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space iamge when doing hires. fix"), + "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"), })) options_templates.update(options_section(('face-restoration', "Face restoration"), { -- cgit v1.2.3 From 37d7ffb415cd8c69b3c0bb5f61844dde0b169f78 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sat, 15 Oct 2022 15:59:37 +0200 Subject: fix to tokens lenght, addend embs generator, add new features to edit the embedding before the generation using text --- modules/aesthetic_clip.py | 78 ++++++++++++++++++++++++ modules/processing.py | 148 +++++++++++++++++++++++++++++++--------------- modules/sd_hijack.py | 111 ++++++++++++++++++++++------------ modules/shared.py | 4 ++ modules/txt2img.py | 10 +++- modules/ui.py | 47 ++++++++++++--- 6 files changed, 302 insertions(+), 96 deletions(-) create mode 100644 modules/aesthetic_clip.py (limited to 'modules/shared.py') diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py new file mode 100644 index 00000000..f15cfd47 --- /dev/null +++ b/modules/aesthetic_clip.py @@ -0,0 +1,78 @@ +import itertools +import os +from pathlib import Path +import html +import gc + +import gradio as gr +import torch +from PIL import Image +from modules import shared +from modules.shared import device, aesthetic_embeddings +from transformers import CLIPModel, CLIPProcessor + +from tqdm.auto import tqdm + + +def get_all_images_in_folder(folder): + return [os.path.join(folder, f) for f in os.listdir(folder) if + os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)] + + +def check_is_valid_image_file(filename): + return filename.lower().endswith(('.png', '.jpg', '.jpeg')) + + +def batched(dataset, total, n=1): + for ndx in range(0, total, n): + yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))] + + +def iter_to_batched(iterable, n=1): + it = iter(iterable) + while True: + chunk = tuple(itertools.islice(it, n)) + if not chunk: + return + yield chunk + + +def generate_imgs_embd(name, folder, batch_size): + # clipModel = CLIPModel.from_pretrained( + # shared.sd_model.cond_stage_model.clipModel.name_or_path + # ) + model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path).to(device) + processor = CLIPProcessor.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path) + + with torch.no_grad(): + embs = [] + for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size), + desc=f"Generating embeddings for {name}"): + if shared.state.interrupted: + break + inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device) + outputs = model.get_image_features(**inputs).cpu() + embs.append(torch.clone(outputs)) + inputs.to("cpu") + del inputs, outputs + + embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True) + + # The generated embedding will be located here + path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt") + torch.save(embs, path) + + model = model.cpu() + del model + del processor + del embs + gc.collect() + torch.cuda.empty_cache() + res = f""" + Done generating embedding for {name}! + Hypernetwork saved to {html.escape(path)} + """ + shared.update_aesthetic_embeddings() + return gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", + value=sorted(aesthetic_embeddings.keys())[0] if len( + aesthetic_embeddings) > 0 else None), res, "" diff --git a/modules/processing.py b/modules/processing.py index 9a033759..ab68d63a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -20,7 +20,6 @@ import modules.images as images import modules.styles import logging - # some of those options should not be changed at all because they would break the model, so I removed them from options. opt_C = 4 opt_f = 8 @@ -52,8 +51,13 @@ def get_correct_sampler(p): elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img): return sd_samplers.samplers_for_img2img + class StableDiffusionProcessing: - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, + subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, + sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, + restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, + extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None): self.sd_model = sd_model self.outpath_samples: str = outpath_samples self.outpath_grids: str = outpath_grids @@ -104,7 +108,8 @@ class StableDiffusionProcessing: class Processed: - def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None): + def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, + all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None): self.images = images_list self.prompt = p.prompt self.negative_prompt = p.negative_prompt @@ -141,7 +146,8 @@ class Processed: self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0] self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) - self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 + self.subseed = int( + self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 self.all_prompts = all_prompts or [self.prompt] self.all_seeds = all_seeds or [self.seed] @@ -181,39 +187,43 @@ class Processed: return json.dumps(obj) - def infotext(self, p: StableDiffusionProcessing, index): - return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size) + def infotext(self, p: StableDiffusionProcessing, index): + return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], + position_in_batch=index % self.batch_size, iteration=index // self.batch_size) # from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3 def slerp(val, low, high): - low_norm = low/torch.norm(low, dim=1, keepdim=True) - high_norm = high/torch.norm(high, dim=1, keepdim=True) - dot = (low_norm*high_norm).sum(1) + low_norm = low / torch.norm(low, dim=1, keepdim=True) + high_norm = high / torch.norm(high, dim=1, keepdim=True) + dot = (low_norm * high_norm).sum(1) if dot.mean() > 0.9995: return low * val + high * (1 - val) omega = torch.acos(dot) so = torch.sin(omega) - res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high return res -def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None): +def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, + p=None): xs = [] # if we have multiple seeds, this means we are working with batch size>1; this then # enables the generation of additional tensors with noise that the sampler will use during its processing. # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to # produce the same images as with two batches [100], [101]. - if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0): + if p is not None and p.sampler is not None and ( + len(seeds) > 1 and opts.enable_batch_seeds or opts.eta_noise_seed_delta > 0): sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))] else: sampler_noises = None for i, seed in enumerate(seeds): - noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8) + noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else ( + shape[0], seed_resize_from_h // 8, seed_resize_from_w // 8) subnoise = None if subseeds is not None: @@ -241,7 +251,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see dx = max(-dx, 0) dy = max(-dy, 0) - x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w] + x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w] noise = x if sampler_noises is not None: @@ -293,14 +303,20 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration "Seed": all_seeds[index], "Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Size": f"{p.width}x{p.height}", - "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), - "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')), + "Model hash": getattr(p, 'sd_model_hash', + None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), + "Model": ( + None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace( + ',', '').replace(':', '')), + "Hypernet": ( + None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace( + ':', '')), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), - "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), + "Seed resize from": ( + None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), "Denoising strength": getattr(p, 'denoising_strength', None), "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta), "Clip skip": None if clip_skip <= 1 else clip_skip, @@ -309,7 +325,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params.update(p.extra_generation_params) - generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None]) + generation_params_text = ", ".join( + [k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None]) negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else "" @@ -317,7 +334,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, - aesthetic_imgs=None,aesthetic_slerp=False) -> Processed: + aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", + aesthetic_slerp_angle=0.15, + aesthetic_text_negative=False) -> Processed: """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" aesthetic_lr = float(aesthetic_lr) @@ -385,7 +404,7 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh for n in range(p.n_iter): if state.skipped: state.skipped = False - + if state.interrupted: break @@ -396,16 +415,19 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh if (len(prompts) == 0): break - #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) - #c = p.sd_model.get_learned_conditioning(prompts) + # uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) + # c = p.sd_model.get_learned_conditioning(prompts) with devices.autocast(): if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): - shared.sd_model.cond_stage_model.set_aesthetic_params(0, 0, 0) + shared.sd_model.cond_stage_model.set_aesthetic_params() uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight, - aesthetic_steps, aesthetic_imgs,aesthetic_slerp) + aesthetic_steps, aesthetic_imgs, + aesthetic_slerp, aesthetic_imgs_text, + aesthetic_slerp_angle, + aesthetic_text_negative) c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) if len(model_hijack.comments) > 0: @@ -413,13 +435,13 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh comments[comment] = 1 if p.n_iter > 1: - shared.state.job = f"Batch {n+1} out of {p.n_iter}" + shared.state.job = f"Batch {n + 1} out of {p.n_iter}" with devices.autocast(): - samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) + samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, + subseed_strength=p.subseed_strength) if state.interrupted or state.skipped: - # if we are interrupted, sample returns just noise # use the image collected previously in sampler loop samples_ddim = shared.state.current_latent @@ -445,7 +467,9 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh if p.restore_faces: if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration: - images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration") + images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], + opts.samples_format, info=infotext(n, i), p=p, + suffix="-before-face-restoration") devices.torch_gc() @@ -456,7 +480,8 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh if p.color_corrections is not None and i < len(p.color_corrections): if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction: - images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction") + images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, + info=infotext(n, i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) if p.overlay_images is not None and i < len(p.overlay_images): @@ -474,7 +499,8 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh image = image.convert('RGB') if opts.samples_save and not p.do_not_save_samples: - images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p) + images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, + info=infotext(n, i), p=p) text = infotext(n, i) infotexts.append(text) @@ -482,7 +508,7 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh image.info["parameters"] = text output_images.append(image) - del x_samples_ddim + del x_samples_ddim devices.torch_gc() @@ -504,10 +530,13 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh index_of_first_image = 1 if opts.grid_save: - images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) + images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, + info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) devices.torch_gc() - return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) + return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), + subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, + index_of_first_image=index_of_first_image, infotexts=infotexts) class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): @@ -543,25 +572,34 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) if not self.enable_hr: - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, + subseeds=subseeds, subseed_strength=self.subseed_strength, + seed_resize_from_h=self.seed_resize_from_h, + seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) return samples - x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, + subseeds=subseeds, subseed_strength=self.subseed_strength, + seed_resize_from_h=self.seed_resize_from_h, + seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f - samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2] + samples = samples[:, :, truncate_y // 2:samples.shape[2] - truncate_y // 2, + truncate_x // 2:samples.shape[3] - truncate_x // 2] if self.scale_latent: - samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") + samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), + mode="bilinear") else: decoded_samples = decode_first_stage(self.sd_model, samples) if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None": - decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), mode="bilinear") + decoded_samples = torch.nn.functional.interpolate(decoded_samples, size=(self.height, self.width), + mode="bilinear") else: lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) @@ -585,13 +623,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) - noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, + subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, + seed_resize_from_w=self.seed_resize_from_w, p=self) # GC now before running the next img2img to prevent running out of memory x = None devices.torch_gc() - samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps) + samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, + steps=self.steps) return samples @@ -599,7 +640,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): sampler = None - def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, **kwargs): + def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, + inpainting_fill=0, inpaint_full_res=True, inpaint_full_res_padding=0, inpainting_mask_invert=0, + **kwargs): super().__init__(**kwargs) self.init_images = init_images @@ -607,7 +650,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.denoising_strength: float = denoising_strength self.init_latent = None self.image_mask = mask - #self.image_unblurred_mask = None + # self.image_unblurred_mask = None self.latent_mask = None self.mask_for_overlay = None self.mask_blur = mask_blur @@ -619,7 +662,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.nmask = None def init(self, all_prompts, all_seeds, all_subseeds): - self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model) + self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, + self.sd_model) crop_region = None if self.image_mask is not None: @@ -628,7 +672,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.inpainting_mask_invert: self.image_mask = ImageOps.invert(self.image_mask) - #self.image_unblurred_mask = self.image_mask + # self.image_unblurred_mask = self.image_mask if self.mask_blur > 0: self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)) @@ -642,7 +686,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): mask = mask.crop(crop_region) self.image_mask = images.resize_image(2, mask, self.width, self.height) - self.paste_to = (x1, y1, x2-x1, y2-y1) + self.paste_to = (x1, y1, x2 - x1, y2 - y1) else: self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height) np_mask = np.array(self.image_mask) @@ -665,7 +709,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.image_mask is not None: image_masked = Image.new('RGBa', (image.width, image.height)) - image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) + image_masked.paste(image.convert("RGBA").convert("RGBa"), + mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) self.overlay_images.append(image_masked.convert('RGBA')) @@ -714,12 +759,17 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): # this needs to be fixed to be done in sample() using actual seeds for batches if self.inpainting_fill == 2: - self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask + self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], + all_seeds[ + 0:self.init_latent.shape[ + 0]]) * self.nmask elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, + subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, + seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 6d5196fe..192883b2 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -14,7 +14,8 @@ from modules.sd_hijack_optimizations import invokeAI_mps_available import ldm.modules.attention import ldm.modules.diffusionmodules.model -from transformers import CLIPVisionModel, CLIPModel +from tqdm import trange +from transformers import CLIPVisionModel, CLIPModel, CLIPTokenizer import torch.optim as optim import copy @@ -22,21 +23,25 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward + def apply_optimizations(): undo_optimizations() ldm.modules.diffusionmodules.model.nonlinearity = silu - if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)): + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and ( + 6, 0) <= torch.cuda.get_device_capability(shared.device) <= (8, 6)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward elif cmd_opts.opt_split_attention_v1: print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 - elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): + elif not cmd_opts.disable_opt_split_attention and ( + cmd_opts.opt_split_attention_invokeai or not torch.cuda.is_available()): if not invokeAI_mps_available and shared.device.type == 'mps': - print("The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.") + print( + "The InvokeAI cross attention optimization for MPS requires the psutil package which is not installed.") print("Applying v1 cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1 else: @@ -112,14 +117,16 @@ class StableDiffusionModelHijack: _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text]) return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) + def slerp(low, high, val): - low_norm = low/torch.norm(low, dim=1, keepdim=True) - high_norm = high/torch.norm(high, dim=1, keepdim=True) - omega = torch.acos((low_norm*high_norm).sum(1)) + low_norm = low / torch.norm(low, dim=1, keepdim=True) + high_norm = high / torch.norm(high, dim=1, keepdim=True) + omega = torch.acos((low_norm * high_norm).sum(1)) so = torch.sin(omega) - res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high return res + class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): super().__init__() @@ -128,6 +135,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.wrapped.transformer.name_or_path ) del self.clipModel.vision_model + self.tokenizer = CLIPTokenizer.from_pretrained(self.wrapped.transformer.name_or_path) self.hijack: StableDiffusionModelHijack = hijack self.tokenizer = wrapped.tokenizer # self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval() @@ -139,7 +147,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] - tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k] + tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if + '(' in k or ')' in k or '[' in k or ']' in k] for text, ident in tokens_with_parens: mult = 1.0 for c in text: @@ -155,8 +164,13 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult - def set_aesthetic_params(self, aesthetic_lr, aesthetic_weight, aesthetic_steps, image_embs_name=None, - aesthetic_slerp=True): + def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, + aesthetic_slerp=True, aesthetic_imgs_text="", + aesthetic_slerp_angle=0.15, + aesthetic_text_negative=False): + self.aesthetic_imgs_text = aesthetic_imgs_text + self.aesthetic_slerp_angle = aesthetic_slerp_angle + self.aesthetic_text_negative = aesthetic_text_negative self.slerp = aesthetic_slerp self.aesthetic_lr = aesthetic_lr self.aesthetic_weight = aesthetic_weight @@ -180,7 +194,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): else: parsed = [[line, 1.0]] - tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"] + tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)[ + "input_ids"] fixes = [] remade_tokens = [] @@ -196,18 +211,20 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if token == self.comma_token: last_comma = len(remade_tokens) - elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), 1) % 75 == 0 and last_comma != -1 and len(remade_tokens) - last_comma <= opts.comma_padding_backtrack: + elif opts.comma_padding_backtrack != 0 and max(len(remade_tokens), + 1) % 75 == 0 and last_comma != -1 and len( + remade_tokens) - last_comma <= opts.comma_padding_backtrack: last_comma += 1 reloc_tokens = remade_tokens[last_comma:] reloc_mults = multipliers[last_comma:] remade_tokens = remade_tokens[:last_comma] length = len(remade_tokens) - + rem = int(math.ceil(length / 75)) * 75 - length remade_tokens += [id_end] * rem + reloc_tokens multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults - + if embedding is None: remade_tokens.append(token) multipliers.append(weight) @@ -248,7 +265,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if line in cache: remade_tokens, fixes, multipliers = cache[line] else: - remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments) + remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, + hijack_comments) token_count = max(current_token_count, token_count) cache[line] = (remade_tokens, fixes, multipliers) @@ -259,7 +277,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - def process_text_old(self, text): id_start = self.wrapped.tokenizer.bos_token_id id_end = self.wrapped.tokenizer.eos_token_id @@ -289,7 +306,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): while i < len(tokens): token = tokens[i] - embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i) + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, + i) mult_change = self.token_mults.get(token) if opts.enable_emphasis else None if mult_change is not None: @@ -312,11 +330,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ovf = remade_tokens[maxlen - 2:] overflowing_words = [vocab.get(int(x), "") for x in ovf] overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words)) - hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") + hijack_comments.append( + f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n") token_count = len(remade_tokens) remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens)) - remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end] + remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end] cache[tuple_tokens] = (remade_tokens, fixes, multipliers) multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers)) @@ -326,23 +345,26 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): hijack_fixes.append(fixes) batch_multipliers.append(multipliers) return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count - + def forward(self, text): use_old = opts.use_old_emphasis_implementation if use_old: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text) + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old( + text) else: - batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text) + batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text( + text) self.hijack.comments += hijack_comments if len(used_custom_terms) > 0: - self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) - + self.hijack.comments.append( + "Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms])) + if use_old: self.hijack.fixes = hijack_fixes return self.process_tokens(remade_batch_tokens, batch_multipliers) - + z = None i = 0 while max(map(len, remade_batch_tokens)) != 0: @@ -356,7 +378,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if fix[0] == i: fixes.append(fix[1]) self.hijack.fixes.append(fixes) - + tokens = [] multipliers = [] for j in range(len(remade_batch_tokens)): @@ -378,19 +400,30 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens] tokens = torch.asarray(remade_batch_tokens).to(device) + + model = copy.deepcopy(self.clipModel).to(device) + model.requires_grad_(True) + if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: + text_embs_2 = model.get_text_features( + **self.tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device)) + if self.aesthetic_text_negative: + text_embs_2 = self.image_embs - text_embs_2 + text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True) + img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle) + else: + img_embs = self.image_embs + with torch.enable_grad(): - model = copy.deepcopy(self.clipModel).to(device) - model.requires_grad_(True) # We optimize the model to maximize the similarity optimizer = optim.Adam( model.text_model.parameters(), lr=self.aesthetic_lr ) - for i in range(self.aesthetic_steps): + for i in trange(self.aesthetic_steps, desc="Aesthetic optimization"): text_embs = model.get_text_features(input_ids=tokens) text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True) - sim = text_embs @ self.image_embs.T + sim = text_embs @ img_embs.T loss = -sim optimizer.zero_grad() loss.mean().backward() @@ -405,6 +438,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): model.cpu() del model + zn = torch.concat([zn for i in range(z.shape[1] // 77)], 1) if self.slerp: z = slerp(z, zn, self.aesthetic_weight) else: @@ -413,15 +447,16 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): remade_batch_tokens = rem_tokens batch_multipliers = rem_multipliers i += 1 - + return z - - + def process_tokens(self, remade_batch_tokens, batch_multipliers): if not opts.use_old_emphasis_implementation: - remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens] + remade_batch_tokens = [ + [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in + remade_batch_tokens] batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers] - + tokens = torch.asarray(remade_batch_tokens).to(device) outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) @@ -461,8 +496,8 @@ class EmbeddingsWithFixes(torch.nn.Module): for fixes, tensor in zip(batch_fixes, inputs_embeds): for offset, embedding in fixes: emb = embedding.vec - emb_len = min(tensor.shape[0]-offset-1, emb.shape[0]) - tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]]) + emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) + tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) vecs.append(tensor) diff --git a/modules/shared.py b/modules/shared.py index cf13a10d..7cd608ca 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -95,6 +95,10 @@ loaded_hypernetwork = None aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} +def update_aesthetic_embeddings(): + global aesthetic_embeddings + aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in + os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} def reload_hypernetworks(): global hypernetworks diff --git a/modules/txt2img.py b/modules/txt2img.py index 78342024..eedcdfe0 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -13,7 +13,11 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, - aesthetic_slerp=False, *args): + aesthetic_slerp=False, + aesthetic_imgs_text="", + aesthetic_slerp_angle=0.15, + aesthetic_text_negative=False, + *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -47,7 +51,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: processed = modules.scripts.scripts_txt2img.run(p, *args) if processed is None: - processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp) + processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp,aesthetic_imgs_text, + aesthetic_slerp_angle, + aesthetic_text_negative) shared.total_tqdm.clear() diff --git a/modules/ui.py b/modules/ui.py index d961d126..e98e2113 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -41,6 +41,7 @@ from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui import modules.hypernetworks.ui +import modules.aesthetic_clip # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -449,7 +450,7 @@ def create_toprow(is_img2img): with gr.Row(): negative_prompt = gr.Textbox(label="Negative prompt", elem_id="negative_prompt", show_label=False, placeholder="Negative prompt", lines=2) with gr.Column(scale=1, elem_id="roll_col"): - sh = gr.Button(elem_id="sh", visible=True) + sh = gr.Button(elem_id="sh", visible=True) with gr.Column(scale=1, elem_id="style_neg_col"): prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) @@ -536,9 +537,13 @@ def create_ui(wrap_gradio_gpu_call): height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) with gr.Group(): - aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") - aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.7) - aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=50) + aesthetic_lr = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.0001") + aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) + aesthetic_steps = gr.Slider(minimum=0, maximum=256, step=1, label="Aesthetic steps", value=5) + with gr.Row(): + aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="") + aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) + aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None) aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) @@ -617,7 +622,10 @@ def create_ui(wrap_gradio_gpu_call): aesthetic_weight, aesthetic_steps, aesthetic_imgs, - aesthetic_slerp + aesthetic_slerp, + aesthetic_imgs_text, + aesthetic_slerp_angle, + aesthetic_text_negative ] + custom_inputs, outputs=[ txt2img_gallery, @@ -721,7 +729,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Row(): inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False) - inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=256, step=4, value=32) + inpaint_full_res_padding = gr.Slider(label='Inpaint at full resolution padding, pixels', minimum=0, maximum=1024, step=4, value=32) with gr.TabItem('Batch img2img', id='batch'): hidden = '
Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' @@ -1071,6 +1079,17 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): create_embedding = gr.Button(value="Create embedding", variant='primary') + with gr.Tab(label="Create images embedding"): + new_embedding_name_ae = gr.Textbox(label="Name") + process_src_ae = gr.Textbox(label='Source directory') + batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256) + with gr.Row(): + with gr.Column(scale=3): + gr.HTML(value="") + + with gr.Column(): + create_embedding_ae = gr.Button(value="Create images embedding", variant='primary') + with gr.Tab(label="Create hypernetwork"): new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) @@ -1139,7 +1158,7 @@ def create_ui(wrap_gradio_gpu_call): fn=modules.textual_inversion.ui.create_embedding, inputs=[ new_embedding_name, - initialization_text, + process_src, nvpt, ], outputs=[ @@ -1149,6 +1168,20 @@ def create_ui(wrap_gradio_gpu_call): ] ) + create_embedding_ae.click( + fn=modules.aesthetic_clip.generate_imgs_embd, + inputs=[ + new_embedding_name_ae, + process_src_ae, + batch_ae + ], + outputs=[ + aesthetic_imgs, + ti_output, + ti_outcome, + ] + ) + create_hypernetwork.click( fn=modules.hypernetworks.ui.create_hypernetwork, inputs=[ -- cgit v1.2.3 From 3395ba493f93214cf037d084d45693a37610bd85 Mon Sep 17 00:00:00 2001 From: ddPn08 Date: Sun, 16 Oct 2022 09:24:01 +0900 Subject: Allow specifying the region of ngrok. --- modules/ngrok.py | 8 +++++--- modules/shared.py | 1 + modules/ui.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/ngrok.py b/modules/ngrok.py index 7d03a6df..5c5f349a 100644 --- a/modules/ngrok.py +++ b/modules/ngrok.py @@ -1,12 +1,14 @@ from pyngrok import ngrok, conf, exception -def connect(token, port): +def connect(token, port, region): if token == None: token = 'None' - conf.get_default().auth_token = token + config = conf.PyngrokConfig( + auth_token=token, region=region + ) try: - public_url = ngrok.connect(port).public_url + public_url = ngrok.connect(port, pyngrok_config=config).public_url except exception.PyngrokNgrokError: print(f'Invalid ngrok authtoken, ngrok connection aborted.\n' f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken') diff --git a/modules/shared.py b/modules/shared.py index fa30bbb0..dcab0af9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -40,6 +40,7 @@ parser.add_argument("--unload-gfpgan", action='store_true', help="does not do an parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) +parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us") parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer')) parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(models_path, 'GFPGAN')) parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN')) diff --git a/modules/ui.py b/modules/ui.py index 08fa72c6..5c0eaf73 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -56,7 +56,7 @@ if not cmd_opts.share and not cmd_opts.listen: if cmd_opts.ngrok != None: import modules.ngrok as ngrok print('ngrok authtoken detected, trying to connect...') - ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860) + ngrok.connect(cmd_opts.ngrok, cmd_opts.port if cmd_opts.port != None else 7860, cmd_opts.ngrok_region) def gr_show(visible=True): -- cgit v1.2.3 From 523140d7805c644700009b8a2483ff4eb4a22304 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sun, 16 Oct 2022 10:23:30 +0200 Subject: ui fix --- modules/aesthetic_clip.py | 3 +-- modules/sd_hijack.py | 3 +-- modules/shared.py | 2 ++ modules/ui.py | 24 ++++++++++++++---------- 4 files changed, 18 insertions(+), 14 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py index 68264284..ccb35c73 100644 --- a/modules/aesthetic_clip.py +++ b/modules/aesthetic_clip.py @@ -74,5 +74,4 @@ def generate_imgs_embd(name, folder, batch_size): """ shared.update_aesthetic_embeddings() return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding", - value=sorted(shared.aesthetic_embeddings.keys())[0] if len( - shared.aesthetic_embeddings) > 0 else None), res, "" + value="None"), res, "" diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 01fcb78f..2de2eed5 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -392,8 +392,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z1 = self.process_tokens(tokens, multipliers) z = z1 if z is None else torch.cat((z, z1), axis=-2) - if len(text[ - 0]) != 0 and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None: + if self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None: if not opts.use_old_emphasis_implementation: remade_batch_tokens = [ [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in diff --git a/modules/shared.py b/modules/shared.py index 3c5ffef1..e2c98b2d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -96,11 +96,13 @@ loaded_hypernetwork = None aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} +aesthetic_embeddings = aesthetic_embeddings | {"None": None} def update_aesthetic_embeddings(): global aesthetic_embeddings aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} + aesthetic_embeddings = aesthetic_embeddings | {"None": None} def reload_hypernetworks(): global hypernetworks diff --git a/modules/ui.py b/modules/ui.py index 13ba3142..4069f0d2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -594,19 +594,23 @@ def create_ui(wrap_gradio_gpu_call): height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) with gr.Group(): - aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001") - - aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) - aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) + with gr.Accordion("Open for Clip Aesthetic!",open=False): + with gr.Row(): + aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) + aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) - with gr.Row(): - aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="") - aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) - aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) + with gr.Row(): + aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001") + aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) + aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), + label="Aesthetic imgs embedding", + value="None") - aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), label="Aesthetic imgs embedding", value=sorted(aesthetic_embeddings.keys())[0] if len(aesthetic_embeddings) > 0 else None) + with gr.Row(): + aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="") + aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) + aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) - aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) with gr.Row(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) -- cgit v1.2.3 From 9324cdaa3199d65c182858785dd1eca42b192b8e Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sun, 16 Oct 2022 17:53:56 +0200 Subject: ui fix, re organization of the code --- modules/aesthetic_clip.py | 154 +++++++++++++++++++++++++++++++++-- modules/img2img.py | 14 +++- modules/processing.py | 29 ++----- modules/sd_hijack.py | 102 ++--------------------- modules/sd_models.py | 5 +- modules/shared.py | 14 +++- modules/textual_inversion/dataset.py | 2 +- modules/txt2img.py | 18 ++-- modules/ui.py | 52 +++++++----- 9 files changed, 233 insertions(+), 157 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py index ccb35c73..34efa931 100644 --- a/modules/aesthetic_clip.py +++ b/modules/aesthetic_clip.py @@ -1,3 +1,4 @@ +import copy import itertools import os from pathlib import Path @@ -7,11 +8,12 @@ import gc import gradio as gr import torch from PIL import Image -from modules import shared -from modules.shared import device -from transformers import CLIPModel, CLIPProcessor +from torch import optim -from tqdm.auto import tqdm +from modules import shared +from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer +from tqdm.auto import tqdm, trange +from modules.shared import opts, device def get_all_images_in_folder(folder): @@ -37,12 +39,39 @@ def iter_to_batched(iterable, n=1): yield chunk +def create_ui(): + with gr.Group(): + with gr.Accordion("Open for Clip Aesthetic!", open=False): + with gr.Row(): + aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", + value=0.9) + aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) + + with gr.Row(): + aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', + placeholder="Aesthetic learning rate", value="0.0001") + aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) + aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()), + label="Aesthetic imgs embedding", + value="None") + + with gr.Row(): + aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', + placeholder="This text is used to rotate the feature space of the imgs embs", + value="") + aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01, + value=0.1) + aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) + + return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative + + def generate_imgs_embd(name, folder, batch_size): # clipModel = CLIPModel.from_pretrained( # shared.sd_model.cond_stage_model.clipModel.name_or_path # ) - model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path).to(device) - processor = CLIPProcessor.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path) + model = shared.clip_model.to(device) + processor = CLIPProcessor.from_pretrained(model.name_or_path) with torch.no_grad(): embs = [] @@ -63,7 +92,6 @@ def generate_imgs_embd(name, folder, batch_size): torch.save(embs, path) model = model.cpu() - del model del processor del embs gc.collect() @@ -74,4 +102,114 @@ def generate_imgs_embd(name, folder, batch_size): """ shared.update_aesthetic_embeddings() return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding", - value="None"), res, "" + value="None"), \ + gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), + label="Imgs embedding", + value="None"), res, "" + + +def slerp(low, high, val): + low_norm = low / torch.norm(low, dim=1, keepdim=True) + high_norm = high / torch.norm(high, dim=1, keepdim=True) + omega = torch.acos((low_norm * high_norm).sum(1)) + so = torch.sin(omega) + res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high + return res + + +class AestheticCLIP: + def __init__(self): + self.skip = False + self.aesthetic_steps = 0 + self.aesthetic_weight = 0 + self.aesthetic_lr = 0 + self.slerp = False + self.aesthetic_text_negative = "" + self.aesthetic_slerp_angle = 0 + self.aesthetic_imgs_text = "" + + self.image_embs_name = None + self.image_embs = None + self.load_image_embs(None) + + def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, + aesthetic_slerp=True, aesthetic_imgs_text="", + aesthetic_slerp_angle=0.15, + aesthetic_text_negative=False): + self.aesthetic_imgs_text = aesthetic_imgs_text + self.aesthetic_slerp_angle = aesthetic_slerp_angle + self.aesthetic_text_negative = aesthetic_text_negative + self.slerp = aesthetic_slerp + self.aesthetic_lr = aesthetic_lr + self.aesthetic_weight = aesthetic_weight + self.aesthetic_steps = aesthetic_steps + self.load_image_embs(image_embs_name) + + def set_skip(self, skip): + self.skip = skip + + def load_image_embs(self, image_embs_name): + if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None": + image_embs_name = None + self.image_embs_name = None + if image_embs_name is not None and self.image_embs_name != image_embs_name: + self.image_embs_name = image_embs_name + self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device) + self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True) + self.image_embs.requires_grad_(False) + + def __call__(self, z, remade_batch_tokens): + if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None: + tokenizer = shared.sd_model.cond_stage_model.tokenizer + if not opts.use_old_emphasis_implementation: + remade_batch_tokens = [ + [tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in + remade_batch_tokens] + + tokens = torch.asarray(remade_batch_tokens).to(device) + + model = copy.deepcopy(shared.clip_model).to(device) + model.requires_grad_(True) + if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: + text_embs_2 = model.get_text_features( + **tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device)) + if self.aesthetic_text_negative: + text_embs_2 = self.image_embs - text_embs_2 + text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True) + img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle) + else: + img_embs = self.image_embs + + with torch.enable_grad(): + + # We optimize the model to maximize the similarity + optimizer = optim.Adam( + model.text_model.parameters(), lr=self.aesthetic_lr + ) + + for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"): + text_embs = model.get_text_features(input_ids=tokens) + text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True) + sim = text_embs @ img_embs.T + loss = -sim + optimizer.zero_grad() + loss.mean().backward() + optimizer.step() + + zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) + if opts.CLIP_stop_at_last_layers > 1: + zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers] + zn = model.text_model.final_layer_norm(zn) + else: + zn = zn.last_hidden_state + model.cpu() + del model + gc.collect() + torch.cuda.empty_cache() + zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1) + if self.slerp: + z = slerp(z, zn, self.aesthetic_weight) + else: + z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight + + return z diff --git a/modules/img2img.py b/modules/img2img.py index 24126774..4ed80c4b 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -56,7 +56,14 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): +def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, + aesthetic_lr=0, + aesthetic_weight=0, aesthetic_steps=0, + aesthetic_imgs=None, + aesthetic_slerp=False, + aesthetic_imgs_text="", + aesthetic_slerp_angle=0.15, + aesthetic_text_negative=False, *args): is_inpaint = mode == 1 is_batch = mode == 2 @@ -109,6 +116,11 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro inpainting_mask_invert=inpainting_mask_invert, ) + shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), + aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, + aesthetic_slerp_angle, + aesthetic_text_negative) + if shared.cmd_opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/processing.py b/modules/processing.py index 1db26c3e..685f9fcd 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -146,7 +146,8 @@ class Processed: self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0] self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1 - self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 + self.subseed = int( + self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 self.all_prompts = all_prompts or [self.prompt] self.all_seeds = all_seeds or [self.seed] @@ -332,16 +333,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() -def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, - aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", - aesthetic_slerp_angle=0.15, - aesthetic_text_negative=False) -> Processed: +def process_images(p: StableDiffusionProcessing) -> Processed: """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" - aesthetic_lr = float(aesthetic_lr) - aesthetic_weight = float(aesthetic_weight) - aesthetic_steps = int(aesthetic_steps) - if type(p.prompt) == list: assert (len(p.prompt) > 0) else: @@ -417,16 +411,10 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh # uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) # c = p.sd_model.get_learned_conditioning(prompts) with devices.autocast(): - if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): - shared.sd_model.cond_stage_model.set_aesthetic_params() + shared.aesthetic_clip.set_skip(True) uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) - if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): - shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight, - aesthetic_steps, aesthetic_imgs, - aesthetic_slerp, aesthetic_imgs_text, - aesthetic_slerp_angle, - aesthetic_text_negative) + shared.aesthetic_clip.set_skip(False) c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) if len(model_hijack.comments) > 0: @@ -582,7 +570,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f - def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) @@ -600,10 +587,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) - samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] + samples = samples[:, :, self.truncate_y // 2:samples.shape[2] - self.truncate_y // 2, + self.truncate_x // 2:samples.shape[3] - self.truncate_x // 2] if opts.use_scale_latent_for_hires_fix: - samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") + samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), + mode="bilinear") else: decoded_samples = decode_first_stage(self.sd_model, samples) lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 5d0590af..227e7670 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -29,8 +29,8 @@ def apply_optimizations(): ldm.modules.diffusionmodules.model.nonlinearity = silu - - if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): + if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and ( + 6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): print("Applying xformers cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward @@ -118,33 +118,14 @@ class StableDiffusionModelHijack: return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) -def slerp(low, high, val): - low_norm = low / torch.norm(low, dim=1, keepdim=True) - high_norm = high / torch.norm(high, dim=1, keepdim=True) - omega = torch.acos((low_norm * high_norm).sum(1)) - so = torch.sin(omega) - res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high - return res - - class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): def __init__(self, wrapped, hijack): super().__init__() self.wrapped = wrapped - self.clipModel = CLIPModel.from_pretrained( - self.wrapped.transformer.name_or_path - ) - del self.clipModel.vision_model - self.tokenizer = CLIPTokenizer.from_pretrained(self.wrapped.transformer.name_or_path) - self.hijack: StableDiffusionModelHijack = hijack - self.tokenizer = wrapped.tokenizer - # self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval() - self.image_embs_name = None - self.image_embs = None - self.load_image_embs(None) self.token_mults = {} - + self.hijack: StableDiffusionModelHijack = hijack + self.tokenizer = wrapped.tokenizer self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ','][0] tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if @@ -164,28 +145,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): if mult != 1.0: self.token_mults[ident] = mult - def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, - aesthetic_slerp=True, aesthetic_imgs_text="", - aesthetic_slerp_angle=0.15, - aesthetic_text_negative=False): - self.aesthetic_imgs_text = aesthetic_imgs_text - self.aesthetic_slerp_angle = aesthetic_slerp_angle - self.aesthetic_text_negative = aesthetic_text_negative - self.slerp = aesthetic_slerp - self.aesthetic_lr = aesthetic_lr - self.aesthetic_weight = aesthetic_weight - self.aesthetic_steps = aesthetic_steps - self.load_image_embs(image_embs_name) - - def load_image_embs(self, image_embs_name): - if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None": - image_embs_name = None - if image_embs_name is not None and self.image_embs_name != image_embs_name: - self.image_embs_name = image_embs_name - self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device) - self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True) - self.image_embs.requires_grad_(False) - def tokenize_line(self, line, used_custom_terms, hijack_comments): id_end = self.wrapped.tokenizer.eos_token_id @@ -391,58 +350,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): z1 = self.process_tokens(tokens, multipliers) z = z1 if z is None else torch.cat((z, z1), axis=-2) - - if self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None: - if not opts.use_old_emphasis_implementation: - remade_batch_tokens = [ - [self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in - remade_batch_tokens] - - tokens = torch.asarray(remade_batch_tokens).to(device) - - model = copy.deepcopy(self.clipModel).to(device) - model.requires_grad_(True) - if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: - text_embs_2 = model.get_text_features( - **self.tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device)) - if self.aesthetic_text_negative: - text_embs_2 = self.image_embs - text_embs_2 - text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True) - img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle) - else: - img_embs = self.image_embs - - with torch.enable_grad(): - - # We optimize the model to maximize the similarity - optimizer = optim.Adam( - model.text_model.parameters(), lr=self.aesthetic_lr - ) - - for i in trange(self.aesthetic_steps, desc="Aesthetic optimization"): - text_embs = model.get_text_features(input_ids=tokens) - text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True) - sim = text_embs @ img_embs.T - loss = -sim - optimizer.zero_grad() - loss.mean().backward() - optimizer.step() - - zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) - if opts.CLIP_stop_at_last_layers > 1: - zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers] - zn = model.text_model.final_layer_norm(zn) - else: - zn = zn.last_hidden_state - model.cpu() - del model - - zn = torch.concat([zn for i in range(z.shape[1] // 77)], 1) - if self.slerp: - z = slerp(z, zn, self.aesthetic_weight) - else: - z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight - + z = shared.aesthetic_clip(z, remade_batch_tokens) remade_batch_tokens = rem_tokens batch_multipliers = rem_multipliers i += 1 diff --git a/modules/sd_models.py b/modules/sd_models.py index 3aa21ec1..8e4ee435 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -20,7 +20,7 @@ checkpoints_loaded = collections.OrderedDict() try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. - from transformers import logging + from transformers import logging, CLIPModel logging.set_verbosity_error() except Exception: @@ -196,6 +196,9 @@ def load_model(): sd_hijack.model_hijack.hijack(sd_model) + if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path: + shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path) + sd_model.eval() print(f"Model loaded.") diff --git a/modules/shared.py b/modules/shared.py index e2c98b2d..e19ca779 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -3,6 +3,7 @@ import datetime import json import os import sys +from collections import OrderedDict import gradio as gr import tqdm @@ -94,15 +95,15 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None -aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in - os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} -aesthetic_embeddings = aesthetic_embeddings | {"None": None} +aesthetic_embeddings = {} def update_aesthetic_embeddings(): global aesthetic_embeddings aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} - aesthetic_embeddings = aesthetic_embeddings | {"None": None} + aesthetic_embeddings = OrderedDict(**{"None": None}, **aesthetic_embeddings) + +update_aesthetic_embeddings() def reload_hypernetworks(): global hypernetworks @@ -381,6 +382,11 @@ sd_upscalers = [] sd_model = None +clip_model = None + +from modules.aesthetic_clip import AestheticCLIP +aesthetic_clip = AestheticCLIP() + progress_print_out = sys.stdout diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 68ceffe3..23bb4b6a 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -49,7 +49,7 @@ class PersonalizedBase(Dataset): print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): try: - image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.Resampling.BICUBIC) + image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) except Exception: continue diff --git a/modules/txt2img.py b/modules/txt2img.py index 8f394d05..6cbc50fc 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,12 +1,17 @@ import modules.scripts -from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images +from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \ + StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, cmd_opts import modules.shared as shared import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int,aesthetic_lr=0, +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, + restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, + subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, + height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, + firstphase_height: int, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, @@ -41,15 +46,17 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: firstphase_height=firstphase_height if enable_hr else None, ) + shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), + aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, + aesthetic_text_negative) + if cmd_opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) processed = modules.scripts.scripts_txt2img.run(p, *args) if processed is None: - processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp,aesthetic_imgs_text, - aesthetic_slerp_angle, - aesthetic_text_negative) + processed = process_images(p) shared.total_tqdm.clear() @@ -61,4 +68,3 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: processed.images = [] return processed.images, generation_info_js, plaintext_to_html(processed.info) - diff --git a/modules/ui.py b/modules/ui.py index 4069f0d2..0e5d73f0 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -43,7 +43,7 @@ from modules.images import save_image import modules.textual_inversion.ui import modules.hypernetworks.ui -import modules.aesthetic_clip +import modules.aesthetic_clip as aesthetic_clip import modules.images_history as img_his @@ -593,23 +593,25 @@ def create_ui(wrap_gradio_gpu_call): width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) - with gr.Group(): - with gr.Accordion("Open for Clip Aesthetic!",open=False): - with gr.Row(): - aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) - aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) - - with gr.Row(): - aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001") - aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) - aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), - label="Aesthetic imgs embedding", - value="None") - - with gr.Row(): - aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="") - aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) - aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) + # with gr.Group(): + # with gr.Accordion("Open for Clip Aesthetic!",open=False): + # with gr.Row(): + # aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) + # aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) + # + # with gr.Row(): + # aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001") + # aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) + # aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), + # label="Aesthetic imgs embedding", + # value="None") + # + # with gr.Row(): + # aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="") + # aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) + # aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) + + aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative = aesthetic_clip.create_ui() with gr.Row(): @@ -840,6 +842,9 @@ def create_ui(wrap_gradio_gpu_call): width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) + aesthetic_weight_im, aesthetic_steps_im, aesthetic_lr_im, aesthetic_slerp_im, aesthetic_imgs_im, aesthetic_imgs_text_im, aesthetic_slerp_angle_im, aesthetic_text_negative_im = aesthetic_clip.create_ui() + + with gr.Row(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) tiling = gr.Checkbox(label='Tiling', value=False) @@ -944,6 +949,14 @@ def create_ui(wrap_gradio_gpu_call): inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, + aesthetic_lr_im, + aesthetic_weight_im, + aesthetic_steps_im, + aesthetic_imgs_im, + aesthetic_slerp_im, + aesthetic_imgs_text_im, + aesthetic_slerp_angle_im, + aesthetic_text_negative_im, ] + custom_inputs, outputs=[ img2img_gallery, @@ -1283,7 +1296,7 @@ def create_ui(wrap_gradio_gpu_call): ) create_embedding_ae.click( - fn=modules.aesthetic_clip.generate_imgs_embd, + fn=aesthetic_clip.generate_imgs_embd, inputs=[ new_embedding_name_ae, process_src_ae, @@ -1291,6 +1304,7 @@ def create_ui(wrap_gradio_gpu_call): ], outputs=[ aesthetic_imgs, + aesthetic_imgs_im, ti_output, ti_outcome, ] -- cgit v1.2.3 From c8045c5ad4f99deb3a19add06e0457de1df62b05 Mon Sep 17 00:00:00 2001 From: SGKoishi Date: Sun, 16 Oct 2022 10:08:23 -0700 Subject: The hide_ui_dir_config flag also restrict write attempt to path settings --- modules/shared.py | 10 ++++++++++ modules/ui.py | 8 +++++++- 2 files changed, 17 insertions(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index dcab0af9..c2775603 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -77,6 +77,16 @@ parser.add_argument("--disable-safe-unpickle", action='store_true', help="disabl cmd_opts = parser.parse_args() +restricted_opts = [ + "samples_filename_pattern", + "outdir_samples", + "outdir_txt2img_samples", + "outdir_img2img_samples", + "outdir_extras_samples", + "outdir_grids", + "outdir_txt2img_grids", + "outdir_save", +] devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer']) diff --git a/modules/ui.py b/modules/ui.py index 7b0d5a92..43dc88fc 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -25,7 +25,7 @@ import gradio.routes from modules import sd_hijack, sd_models from modules.paths import script_path -from modules.shared import opts, cmd_opts +from modules.shared import opts, cmd_opts, restricted_opts if cmd_opts.deepdanbooru: from modules.deepbooru import get_deepbooru_tags import modules.shared as shared @@ -1430,6 +1430,9 @@ Requested path was: {f} if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False: continue + if cmd_opts.hide_ui_dir_config and key in restricted_opts: + continue + oldval = opts.data.get(key, None) opts.data[key] = value @@ -1447,6 +1450,9 @@ Requested path was: {f} if not opts.same_type(value, opts.data_labels[key].default): return gr.update(visible=True), opts.dumpjson() + if cmd_opts.hide_ui_dir_config and key in restricted_opts: + return gr.update(value=oldval), opts.dumpjson() + oldval = opts.data.get(key, None) opts.data[key] = value -- cgit v1.2.3 From 60251c9456f5472784862896c2f97e38feb42482 Mon Sep 17 00:00:00 2001 From: arcticfaded Date: Mon, 17 Oct 2022 06:58:42 +0000 Subject: initial prototype by borrowing contracts --- modules/api/api.py | 60 ++++++++++++++++++++++++++++++++++++++++++++ modules/processing.py | 2 +- modules/shared.py | 2 +- webui.py | 69 +++++++++++++++++++++++++++++---------------------- 4 files changed, 102 insertions(+), 31 deletions(-) create mode 100644 modules/api/api.py (limited to 'modules/shared.py') diff --git a/modules/api/api.py b/modules/api/api.py new file mode 100644 index 00000000..9d7c699d --- /dev/null +++ b/modules/api/api.py @@ -0,0 +1,60 @@ +from modules.api.processing import StableDiffusionProcessingAPI +from modules.processing import StableDiffusionProcessingTxt2Img, process_images +import modules.shared as shared +import uvicorn +from fastapi import FastAPI, Body, APIRouter +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field, Json +import json +import io +import base64 + +app = FastAPI() + +class TextToImageResponse(BaseModel): + images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") + parameters: Json + info: Json + + +class Api: + def __init__(self, txt2img, img2img, run_extras, run_pnginfo): + self.router = APIRouter() + app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"]) + + def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): + print(txt2imgreq) + p = StableDiffusionProcessingTxt2Img(**vars(txt2imgreq)) + p.sd_model = shared.sd_model + print(p) + processed = process_images(p) + + b64images = [] + for i in processed.images: + buffer = io.BytesIO() + i.save(buffer, format="png") + b64images.append(base64.b64encode(buffer.getvalue())) + + response = { + "images": b64images, + "info": processed.js(), + "parameters": json.dumps(vars(txt2imgreq)) + } + + + return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info)) + + + + def img2imgendoint(self): + raise NotImplementedError + + def extrasendoint(self): + raise NotImplementedError + + def pnginfoendoint(self): + raise NotImplementedError + + def launch(self, server_name, port): + app.include_router(self.router) + uvicorn.run(app, host=server_name, port=port) \ No newline at end of file diff --git a/modules/processing.py b/modules/processing.py index deb6125e..4a7c6ccc 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -723,4 +723,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): del x devices.torch_gc() - return samples + return samples \ No newline at end of file diff --git a/modules/shared.py b/modules/shared.py index c2775603..6c6405fd 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -74,7 +74,7 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) - +parser.add_argument("--api", action='store_true', help="use api=True to launch the api instead of the webui") cmd_opts = parser.parse_args() restricted_opts = [ diff --git a/webui.py b/webui.py index fe0ce321..cd8a99ea 100644 --- a/webui.py +++ b/webui.py @@ -97,40 +97,51 @@ def webui(): os._exit(0) signal.signal(signal.SIGINT, sigint_handler) + + if cmd_opts.api: + from modules.api.api import Api + api = Api(txt2img=modules.txt2img.txt2img, + img2img=modules.img2img.img2img, + run_extras=modules.extras.run_extras, + run_pnginfo=modules.extras.run_pnginfo) - while 1: - - demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) - - app, local_url, share_url = demo.launch( - share=cmd_opts.share, - server_name="0.0.0.0" if cmd_opts.listen else None, - server_port=cmd_opts.port, - debug=cmd_opts.gradio_debug, - auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, - inbrowser=cmd_opts.autolaunch, - prevent_thread_lock=True - ) - - app.add_middleware(GZipMiddleware, minimum_size=1000) + api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", + port=cmd_opts.port if cmd_opts.port else 7861) + else: while 1: - time.sleep(0.5) - if getattr(demo, 'do_restart', False): - time.sleep(0.5) - demo.close() - time.sleep(0.5) - break - sd_samplers.set_samplers() + demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) - print('Reloading Custom Scripts') - modules.scripts.reload_scripts(os.path.join(script_path, "scripts")) - print('Reloading modules: modules.ui') - importlib.reload(modules.ui) - print('Refreshing Model List') - modules.sd_models.list_models() - print('Restarting Gradio') + app, local_url, share_url = demo.launch( + share=cmd_opts.share, + server_name="0.0.0.0" if cmd_opts.listen else None, + server_port=cmd_opts.port, + debug=cmd_opts.gradio_debug, + auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, + inbrowser=cmd_opts.autolaunch, + prevent_thread_lock=True + ) + + app.add_middleware(GZipMiddleware, minimum_size=1000) + + while 1: + time.sleep(0.5) + if getattr(demo, 'do_restart', False): + time.sleep(0.5) + demo.close() + time.sleep(0.5) + break + + sd_samplers.set_samplers() + + print('Reloading Custom Scripts') + modules.scripts.reload_scripts(os.path.join(script_path, "scripts")) + print('Reloading modules: modules.ui') + importlib.reload(modules.ui) + print('Refreshing Model List') + modules.sd_models.list_models() + print('Restarting Gradio') if __name__ == "__main__": -- cgit v1.2.3 From 8c6a981d5d9ef30381ac2327460285111550acbc Mon Sep 17 00:00:00 2001 From: Michoko Date: Mon, 17 Oct 2022 11:05:05 +0200 Subject: Added dark mode switch Launch the UI in dark mode with the --dark-mode switch --- javascript/ui.js | 7 +++++++ modules/shared.py | 2 +- modules/ui.py | 2 ++ 3 files changed, 10 insertions(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/javascript/ui.js b/javascript/ui.js index 9e1bed4c..bfa72885 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -1,5 +1,12 @@ // various functions for interation with ui.py not large enough to warrant putting them in separate files +function go_dark_mode(){ + gradioURL = window.location.href + if (!gradioURL.endsWith('?__theme=dark')) { + window.location.replace(gradioURL + '?__theme=dark'); + } +} + function selected_gallery_index(){ var buttons = gradioApp().querySelectorAll('[style="display: block;"].tabitem .gallery-item') var button = gradioApp().querySelector('[style="display: block;"].tabitem .gallery-item.\\!ring-2') diff --git a/modules/shared.py b/modules/shared.py index c2775603..cbf158e4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -69,13 +69,13 @@ parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image upload parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) +parser.add_argument("--dark-mode", action='store_true', help="launches the UI in dark mode", default=False) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) - cmd_opts = parser.parse_args() restricted_opts = [ "samples_filename_pattern", diff --git a/modules/ui.py b/modules/ui.py index 43dc88fc..a0cd052e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1783,6 +1783,8 @@ for filename in sorted(os.listdir(jsdir)): with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile: javascript += f"\n" +if cmd_opts.dark_mode: + javascript += "\n\n" if 'gradio_routes_templates_response' not in globals(): def template_response(*args, **kwargs): -- cgit v1.2.3 From 665beebc0825a6fad410c8252f27f6f6f0bd900b Mon Sep 17 00:00:00 2001 From: Michoko Date: Mon, 17 Oct 2022 18:24:24 +0200 Subject: Use of a --theme argument for more flexibility Added possibility to set the theme (light or dark) --- javascript/ui.js | 6 +++--- modules/shared.py | 2 +- modules/ui.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) (limited to 'modules/shared.py') diff --git a/javascript/ui.js b/javascript/ui.js index bfa72885..cfd0dcd3 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -1,9 +1,9 @@ // various functions for interation with ui.py not large enough to warrant putting them in separate files -function go_dark_mode(){ +function set_theme(theme){ gradioURL = window.location.href - if (!gradioURL.endsWith('?__theme=dark')) { - window.location.replace(gradioURL + '?__theme=dark'); + if (!gradioURL.includes('?__theme=')) { + window.location.replace(gradioURL + '?__theme=' + theme); } } diff --git a/modules/shared.py b/modules/shared.py index cbf158e4..fa084c69 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -69,7 +69,7 @@ parser.add_argument("--gradio-img2img-tool", type=str, help='gradio image upload parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) -parser.add_argument("--dark-mode", action='store_true', help="launches the UI in dark mode", default=False) +parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None) parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) diff --git a/modules/ui.py b/modules/ui.py index a0cd052e..d41715fa 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1783,8 +1783,8 @@ for filename in sorted(os.listdir(jsdir)): with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile: javascript += f"\n" -if cmd_opts.dark_mode: - javascript += "\n\n" +if cmd_opts.theme is not None: + javascript += f"\n\n" if 'gradio_routes_templates_response' not in globals(): def template_response(*args, **kwargs): -- cgit v1.2.3 From cf47d13c1e11fcb7169bac7488d2c39e579ee491 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 17 Oct 2022 21:15:32 +0300 Subject: localization support --- javascript/localization.js | 146 ++++++++++++++++++++++++++ localizations/Put localization files here.txt | 0 modules/localization.py | 31 ++++++ modules/shared.py | 7 +- modules/ui.py | 33 ++++-- script.js | 10 +- style.css | 2 +- 7 files changed, 211 insertions(+), 18 deletions(-) create mode 100644 javascript/localization.js create mode 100644 localizations/Put localization files here.txt create mode 100644 modules/localization.py (limited to 'modules/shared.py') diff --git a/javascript/localization.js b/javascript/localization.js new file mode 100644 index 00000000..e6644635 --- /dev/null +++ b/javascript/localization.js @@ -0,0 +1,146 @@ + +// localization = {} -- the dict with translations is created by the backend + +ignore_ids_for_localization={ + setting_sd_hypernetwork: 'OPTION', + setting_sd_model_checkpoint: 'OPTION', + setting_realesrgan_enabled_models: 'OPTION', + modelmerger_primary_model_name: 'OPTION', + modelmerger_secondary_model_name: 'OPTION', + modelmerger_tertiary_model_name: 'OPTION', + train_embedding: 'OPTION', + train_hypernetwork: 'OPTION', + txt2img_style_index: 'OPTION', + txt2img_style2_index: 'OPTION', + img2img_style_index: 'OPTION', + img2img_style2_index: 'OPTION', + setting_random_artist_categories: 'SPAN', + setting_face_restoration_model: 'SPAN', + setting_realesrgan_enabled_models: 'SPAN', + extras_upscaler_1: 'SPAN', + extras_upscaler_2: 'SPAN', +} + +re_num = /^[\.\d]+$/ +re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u + +original_lines = {} +translated_lines = {} + +function textNodesUnder(el){ + var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false); + while(n=walk.nextNode()) a.push(n); + return a; +} + +function canBeTranslated(node, text){ + if(! text) return false; + if(! node.parentElement) return false; + + parentType = node.parentElement.nodeName + if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false; + + if (parentType=='OPTION' || parentType=='SPAN'){ + pnode = node + for(var level=0; level<4; level++){ + pnode = pnode.parentElement + if(! pnode) break; + + if(ignore_ids_for_localization[pnode.id] == parentType) return false; + } + } + + if(re_num.test(text)) return false; + if(re_emoji.test(text)) return false; + return true +} + +function getTranslation(text){ + if(! text) return undefined + + if(translated_lines[text] === undefined){ + original_lines[text] = 1 + } + + tl = localization[text] + if(tl !== undefined){ + translated_lines[tl] = 1 + } + + return tl +} + +function processTextNode(node){ + text = node.textContent.trim() + + if(! canBeTranslated(node, text)) return + + tl = getTranslation(text) + if(tl !== undefined){ + node.textContent = tl + } +} + +function processNode(node){ + if(node.nodeType == 3){ + processTextNode(node) + return + } + + if(node.title){ + tl = getTranslation(node.title) + if(tl !== undefined){ + node.title = tl + } + } + + if(node.placeholder){ + tl = getTranslation(node.placeholder) + if(tl !== undefined){ + node.placeholder = tl + } + } + + textNodesUnder(node).forEach(function(node){ + processTextNode(node) + }) +} + +function dumpTranslations(){ + dumped = {} + + Object.keys(original_lines).forEach(function(text){ + if(dumped[text] !== undefined) return + + dumped[text] = localization[text] || text + }) + + return dumped +} + +onUiUpdate(function(m){ + m.forEach(function(mutation){ + mutation.addedNodes.forEach(function(node){ + processNode(node) + }) + }); +}) + + +document.addEventListener("DOMContentLoaded", function() { + processNode(gradioApp()) +}) + +function download_localization() { + text = JSON.stringify(dumpTranslations(), null, 4) + + var element = document.createElement('a'); + element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text)); + element.setAttribute('download', "localization.json"); + element.style.display = 'none'; + document.body.appendChild(element); + + element.click(); + + document.body.removeChild(element); +} diff --git a/localizations/Put localization files here.txt b/localizations/Put localization files here.txt new file mode 100644 index 00000000..e69de29b diff --git a/modules/localization.py b/modules/localization.py new file mode 100644 index 00000000..b1810cda --- /dev/null +++ b/modules/localization.py @@ -0,0 +1,31 @@ +import json +import os +import sys +import traceback + +localizations = {} + + +def list_localizations(dirname): + localizations.clear() + + for file in os.listdir(dirname): + fn, ext = os.path.splitext(file) + if ext.lower() != ".json": + continue + + localizations[fn] = os.path.join(dirname, file) + + +def localization_js(current_localization_name): + fn = localizations.get(current_localization_name, None) + data = {} + if fn is not None: + try: + with open(fn, "r", encoding="utf8") as file: + data = json.load(file) + except Exception: + print(f"Error loading localization from {fn}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + return f"var localization = {json.dumps(data)}\n" diff --git a/modules/shared.py b/modules/shared.py index c2775603..2a2b0427 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, sd_models +from modules import sd_samplers, sd_models, localization from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path @@ -31,6 +31,7 @@ parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") +parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") @@ -103,7 +104,6 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None - def reload_hypernetworks(): global hypernetworks @@ -151,6 +151,8 @@ interrogator = modules.interrogate.InterrogateModels("interrogate") face_restorers = [] +localization.list_localizations(cmd_opts.localizations_dir) + def realesrgan_models_names(): import modules.realesrgan_model @@ -296,6 +298,7 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), + 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { diff --git a/modules/ui.py b/modules/ui.py index 533b1db3..656bab7a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -23,7 +23,7 @@ import gradio as gr import gradio.utils import gradio.routes -from modules import sd_hijack, sd_models +from modules import sd_hijack, sd_models, localization from modules.paths import script_path from modules.shared import opts, cmd_opts, restricted_opts if cmd_opts.deepdanbooru: @@ -1056,10 +1056,10 @@ def create_ui(wrap_gradio_gpu_call): upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") with gr.Group(): - extras_upscaler_2 = gr.Radio(label='Upscaler 2', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_2 = gr.Radio(label='Upscaler 2', celem_id="extras_upscaler_2", hoices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1) with gr.Group(): @@ -1224,10 +1224,10 @@ def create_ui(wrap_gradio_gpu_call): with gr.Tab(label="Train"): gr.HTML(value="

Train an embedding; must specify a directory with a set of 1:1 ratio images

") with gr.Row(): - train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) + train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())) create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name") with gr.Row(): - train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', choices=[x for x in shared.hypernetworks.keys()]) + train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()]) create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name") learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005") batch_size = gr.Number(label='Batch size', value=1, precision=0) @@ -1376,16 +1376,18 @@ def create_ui(wrap_gradio_gpu_call): else: raise Exception(f'bad options item type: {str(t)} for key {key}') + elem_id = "setting_"+key + if info.refresh is not None: if is_quicksettings: - res = comp(label=info.label, value=fun, **(args or {})) - refresh_button = create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) else: with gr.Row(variant="compact"): - res = comp(label=info.label, value=fun, **(args or {})) - refresh_button = create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) + res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {})) + create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) else: - res = comp(label=info.label, value=fun, **(args or {})) + res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {})) return res @@ -1509,6 +1511,9 @@ Requested path was: {f} with gr.Row(): request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + + with gr.Row(): reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary') restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary') @@ -1519,6 +1524,13 @@ Requested path was: {f} _js='function(){}' ) + download_localization.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='download_localization' + ) + def reload_scripts(): modules.scripts.reload_script_body_only() @@ -1784,6 +1796,7 @@ for filename in sorted(os.listdir(jsdir)): with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile: javascript += f"\n" +javascript += f"\n" if 'gradio_routes_templates_response' not in globals(): def template_response(*args, **kwargs): diff --git a/script.js b/script.js index 88f2c839..8b3b67e3 100644 --- a/script.js +++ b/script.js @@ -21,20 +21,20 @@ function onUiTabChange(callback){ uiTabChangeCallbacks.push(callback) } -function runCallback(x){ +function runCallback(x, m){ try { - x() + x(m) } catch (e) { (console.error || console.log).call(console, e.message, e); } } -function executeCallbacks(queue) { - queue.forEach(runCallback) +function executeCallbacks(queue, m) { + queue.forEach(function(x){runCallback(x, m)}) } document.addEventListener("DOMContentLoaded", function() { var mutationObserver = new MutationObserver(function(m){ - executeCallbacks(uiUpdateCallbacks); + executeCallbacks(uiUpdateCallbacks, m); const newTab = get_uiCurrentTab(); if ( newTab && ( newTab !== uiCurrentTab ) ) { uiCurrentTab = newTab; diff --git a/style.css b/style.css index 71eb4d20..9dc4b696 100644 --- a/style.css +++ b/style.css @@ -478,7 +478,7 @@ input[type="range"]{ padding: 0; } -#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name{ +#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{ max-width: 2.5em; min-width: 2.5em; height: 2.4em; -- cgit v1.2.3 From 1df3ff25e6fe2e3f308e45f7a6dd37fb4f1988e6 Mon Sep 17 00:00:00 2001 From: Ryan Voots Date: Mon, 17 Oct 2022 12:58:34 -0400 Subject: Add --nowebui as a means of disabling the webui and run on the other port --- modules/shared.py | 3 ++- webui.py | 35 +++++++++++++++++++++++++---------- 2 files changed, 27 insertions(+), 11 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 6c6405fd..8b436970 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -74,7 +74,8 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) -parser.add_argument("--api", action='store_true', help="use api=True to launch the api instead of the webui") +parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") +parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") cmd_opts = parser.parse_args() restricted_opts = [ diff --git a/webui.py b/webui.py index 6b55fbed..6212be01 100644 --- a/webui.py +++ b/webui.py @@ -95,16 +95,34 @@ def initialize(): signal.signal(signal.SIGINT, sigint_handler) -def api(): +def create_api(app): from modules.api.api import Api api = Api(app) + return api + +def wait_on_server(demo=None): + while 1: + time.sleep(0.5) + if demo and getattr(demo, 'do_restart', False): + time.sleep(0.5) + demo.close() + time.sleep(0.5) + break + +def api_only(): + initialize() + + app = FastAPI() + app.add_middleware(GZipMiddleware, minimum_size=1000) + api = create_api(app) + + api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) def webui(launch_api=False): initialize() while 1: - demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) app, local_url, share_url = demo.launch( @@ -120,15 +138,9 @@ def webui(launch_api=False): app.add_middleware(GZipMiddleware, minimum_size=1000) if (launch_api): - api(app) + create_api(app) - while 1: - time.sleep(0.5) - if getattr(demo, 'do_restart', False): - time.sleep(0.5) - demo.close() - time.sleep(0.5) - break + wait_on_server(demo) sd_samplers.set_samplers() @@ -142,6 +154,9 @@ def webui(launch_api=False): if __name__ == "__main__": + if not cmd_opts.nowebui: + api_only() + if cmd_opts.api: webui(True) else: -- cgit v1.2.3 From 8b02662215917d39f76f86b703a322818d5a8ad4 Mon Sep 17 00:00:00 2001 From: trufty Date: Mon, 17 Oct 2022 10:58:21 -0400 Subject: Disable auto weights swap with config option --- modules/shared.py | 1 + modules/ui.py | 4 ++++ 2 files changed, 5 insertions(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 9603d26e..8a1d1881 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -266,6 +266,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), + "disable_weights_auto_swap": OptionInfo(False, "Disable auto swapping weights to match model hash in prompts"), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), diff --git a/modules/ui.py b/modules/ui.py index 1dae4a65..75eb0b0c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -542,6 +542,10 @@ def apply_setting(key, value): if value is None: return gr.update() + # dont allow model to be swapped when model hash exists in prompt + if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: + return gr.update() + if key == "sd_model_checkpoint": ckpt_info = sd_models.get_closet_checkpoint_match(value) -- cgit v1.2.3 From d2f459c5cf9f728256775dc1c3380c7e9a7e27fb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 18 Oct 2022 14:22:52 +0300 Subject: clarify the comment for the new option from #2959 and move it to UI section. --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 8a1d1881..c0d87168 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -266,7 +266,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"), "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), - "disable_weights_auto_swap": OptionInfo(False, "Disable auto swapping weights to match model hash in prompts"), "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "filter_nsfw": OptionInfo(False, "Filter NSFW content"), 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), @@ -294,6 +293,7 @@ options_templates.update(options_section(('ui', "User interface"), { "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"), + "disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), "font": OptionInfo("", "Font for image grids that have text"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), -- cgit v1.2.3 From 4c605c5174a9b211c3a88e9aff5f5be92b53fd92 Mon Sep 17 00:00:00 2001 From: DepFA <35278260+dfaker@users.noreply.github.com> Date: Sun, 16 Oct 2022 17:24:06 +0100 Subject: add shared option for update check --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index c0d87168..50dc46ae 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -76,6 +76,7 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) +parser.add_argument("--update-check", action='store_true', help="enable http check to confirm that the currently running version is the most recent release.", default=False) cmd_opts = parser.parse_args() restricted_opts = [ -- cgit v1.2.3 From eb299527b1e5d1f83a14641647fca72e8fb305ac Mon Sep 17 00:00:00 2001 From: yfszzx Date: Tue, 18 Oct 2022 20:14:11 +0800 Subject: Image browser --- javascript/images_history.js | 19 ++-- modules/images_history.py | 227 ++++++++++++++++++++++++++++--------------- modules/shared.py | 7 +- modules/ui.py | 2 +- uitest.bat | 2 + uitest.py | 124 +++++++++++++++++++++++ 6 files changed, 289 insertions(+), 92 deletions(-) create mode 100644 uitest.bat create mode 100644 uitest.py (limited to 'modules/shared.py') diff --git a/javascript/images_history.js b/javascript/images_history.js index 3c028bc6..182d730b 100644 --- a/javascript/images_history.js +++ b/javascript/images_history.js @@ -145,9 +145,10 @@ function images_history_enable_del_buttons(){ } function images_history_init(){ - var loaded = gradioApp().getElementById("images_history_reconstruct_directory") - if (loaded){ - var init_status = loaded.querySelector("input").checked + // var loaded = gradioApp().getElementById("images_history_reconstruct_directory") + // if (loaded){ + // var init_status = loaded.querySelector("input").checked + if (gradioApp().getElementById("images_history_finish_render")){ for (var i in images_history_tab_list ){ tab = images_history_tab_list[i]; gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor"); @@ -163,19 +164,17 @@ function images_history_init(){ for (var i in images_history_tab_list){ var tabname = images_history_tab_list[i] tab_btns[i].setAttribute("tabname", tabname); - if (init_status){ - tab_btns[i].addEventListener('click', images_history_click_tab); - } - } - if (init_status){ - tab_btns[0].click(); + // if (!init_status){ + // tab_btns[i].addEventListener('click', images_history_click_tab); + // } + tab_btns[i].addEventListener('click', images_history_click_tab); } } else { setTimeout(images_history_init, 500); } } -var images_history_tab_list = ["txt2img", "img2img", "extras", "saved"]; +var images_history_tab_list = ["custom", "txt2img", "img2img", "extras", "saved"]; setTimeout(images_history_init, 500); document.addEventListener("DOMContentLoaded", function() { var mutationObserver = new MutationObserver(function(m){ diff --git a/modules/images_history.py b/modules/images_history.py index 20324557..d56f3a25 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -4,6 +4,7 @@ import time import hashlib import gradio system_bak_path = "webui_log_and_bak" +browser_tabname = "custom" def is_valid_date(date): try: time.strptime(date, "%Y%m%d") @@ -99,13 +100,15 @@ def auto_sorting(dir_name): date_list.append(today) return sorted(date_list, reverse=True) -def archive_images(dir_name, date_to): +def archive_images(dir_name, date_to): + filenames = [] loads_num =int(opts.images_history_num_per_page * opts.images_history_pages_num) + today = time.strftime("%Y%m%d",time.localtime(time.time())) + date_to = today if date_to is None or date_to == "" else date_to + date_to_bak = date_to if opts.images_history_reconstruct_directory: - date_list = auto_sorting(dir_name) - today = time.strftime("%Y%m%d",time.localtime(time.time())) - date_to = today if date_to is None or date_to == "" else date_to + date_list = auto_sorting(dir_name) for date in date_list: if date <= date_to: path = os.path.join(dir_name, date) @@ -120,7 +123,7 @@ def archive_images(dir_name, date_to): tmparray = [(os.path.getmtime(file), file) for file in filenames ] date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400 filenames = [] - date_list = {} + date_list = {date_to:None} date = time.strftime("%Y%m%d",time.localtime(time.time())) for t, f in tmparray: date = time.strftime("%Y%m%d",time.localtime(t)) @@ -133,22 +136,29 @@ def archive_images(dir_name, date_to): date = sort_array[loads_num][2] filenames = [x[1] for x in sort_array] else: - date = None if len(sort_array) == 0 else sort_array[-1][2] + date = date_to if len(sort_array) == 0 else sort_array[-1][2] filenames = [x[1] for x in sort_array] - filenames = [x[1] for x in sort_array if x[2]>= date] - _, image_list, _, visible_num = get_recent_images(1, 0, filenames) + filenames = [x[1] for x in sort_array if x[2]>= date] + num = len(filenames) + last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000)) + date = date[:4] + "-" + date[4:6] + "-" + date[6:8] + date_to_bak = date_to_bak[:4] + "-" + date_to_bak[4:6] + "-" + date_to_bak[6:8] + load_info = f"Loaded {(num + 1) // opts.images_history_pages_num} pades, {num} images, during {date} - {date_to_bak}" + _, image_list, _, _, visible_num = get_recent_images(1, 0, filenames) return ( gradio.Dropdown.update(choices=date_list, value=date_to), - date, + load_info, filenames, 1, image_list, "", - visible_num + "", + visible_num, + last_date_from ) -def newest_click(dir_name, date_to): - return archive_images(dir_name, time.strftime("%Y%m%d",time.localtime(time.time()))) + + def delete_image(delete_num, name, filenames, image_index, visible_num): if name == "": @@ -196,7 +206,29 @@ def get_recent_images(page_index, step, filenames): length = len(filenames) visible_num = num_of_imgs_per_page if idx_frm + num_of_imgs_per_page <= length else length % num_of_imgs_per_page visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num - return page_index, image_list, "", visible_num + return page_index, image_list, "", "", visible_num + +def newest_click(date_to): + if date_to is None: + return time.strftime("%Y%m%d",time.localtime(time.time())), [] + else: + return None, [] +def forward_click(last_date_from, date_to_recorder): + if len(date_to_recorder) == 0: + return None, [] + if last_date_from == date_to_recorder[-1]: + date_to_recorder = date_to_recorder[:-1] + if len(date_to_recorder) == 0: + return None, [] + return date_to_recorder[-1], date_to_recorder[:-1] + +def backward_click(last_date_from, date_to_recorder): + if last_date_from is None or last_date_from == "": + return time.strftime("%Y%m%d",time.localtime(time.time())), [] + if len(date_to_recorder) == 0 or last_date_from != date_to_recorder[-1]: + date_to_recorder.append(last_date_from) + return last_date_from, date_to_recorder + def first_page_click(page_index, filenames): return get_recent_images(1, 0, filenames) @@ -214,13 +246,33 @@ def page_index_change(page_index, filenames): return get_recent_images(page_index, 0, filenames) def show_image_info(tabname_box, num, page_index, filenames): - file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))] - return file, num, file + file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))] + tm = time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + return file, tm, num, file def enable_page_buttons(): return gradio.update(visible=True) +def change_dir(img_dir, date_to): + warning = None + try: + if os.path.exists(img_dir): + try: + f = os.listdir(img_dir) + except: + warning = f"'{img_dir} is not a directory" + else: + warning = "The directory is not exist" + except: + warning = "The format of the directory is incorrect" + if warning is None: + today = time.strftime("%Y%m%d",time.localtime(time.time())) + return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today + else: + return gradio.update(visible=True), gradio.update(visible=False), warning, date_to + def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): + custom_dir = False if tabname == "txt2img": dir_name = opts.outdir_txt2img_samples elif tabname == "img2img": @@ -229,69 +281,85 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): dir_name = opts.outdir_extras_samples elif tabname == "saved": dir_name = opts.outdir_save + else: + custom_dir = True + dir_name = None + + if not custom_dir: + d = dir_name.split("/") + dir_name = d[0] + for p in d[1:]: + dir_name = os.path.join(dir_name, p) + if not os.path.exists(dir_name): + os.makedirs(dir_name) - d = dir_name.split("/") - dir_name = d[0] - for p in d[1:]: - dir_name = os.path.join(dir_name, p) - if not os.path.exists(dir_name): - os.makedirs(dir_name) - - with gr.Column() as page_panel: - with gr.Row(visible=False) as turn_page_buttons: - renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page") - first_page = gr.Button('First Page') - prev_page = gr.Button('Prev Page') - page_index = gr.Number(value=1, label="Page Index") - next_page = gr.Button('Next Page') - end_page = gr.Button('End Page') - - with gr.Row(elem_id=tabname + "_images_history"): - with gr.Column(scale=2): - with gr.Row(): - newest = gr.Button('Reload', elem_id=tabname + "_images_history_start") - date_from = gr.Textbox(label="Date from", interactive=False) - date_to = gr.Dropdown(label="Date to") - - history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6) - with gr.Row(): - delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") - delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") - - with gr.Column(): - with gr.Row(): - if tabname != "saved": - save_btn = gr.Button('Save') - pnginfo_send_to_txt2img = gr.Button('Send to txt2img') - pnginfo_send_to_img2img = gr.Button('Send to img2img') - with gr.Row(): - with gr.Column(): - img_file_info = gr.Textbox(label="Generate Info", interactive=False) - img_file_name = gr.Textbox(value="", label="File Name", interactive=False) + with gr.Column() as page_panel: + with gr.Row(): + img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory") + with gr.Row(visible=False) as warning: + warning_box = gr.Textbox("Message", interactive=False) + with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel: + with gr.Column(scale=2): + with gr.Row(): + backward = gr.Button('Backward') + date_to = gr.Dropdown(label="Date to") + forward = gr.Button('Forward') + newest = gr.Button('Reload', elem_id=tabname + "_images_history_start") + with gr.Row(): + load_info = gr.Textbox(show_label=False, interactive=False) + with gr.Row(visible=False) as turn_page_buttons: + renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page") + first_page = gr.Button('First Page') + prev_page = gr.Button('Prev Page') + page_index = gr.Number(value=1, label="Page Index") + next_page = gr.Button('Next Page') + end_page = gr.Button('End Page') + + history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=opts.images_history_grid_num) + with gr.Row(): + delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") + delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") - # hiden items - with gr.Row(visible=False): - visible_img_num = gr.Number() - img_path = gr.Textbox(dir_name) - tabname_box = gr.Textbox(tabname) - image_index = gr.Textbox(value=-1) - set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") - filenames = gr.State() - all_images_list = gr.State() - hidden = gr.Image(type="pil") - info1 = gr.Textbox() - info2 = gr.Textbox() + with gr.Column(): + with gr.Row(): + if tabname != "saved": + save_btn = gr.Button('Save') + pnginfo_send_to_txt2img = gr.Button('Send to txt2img') + pnginfo_send_to_img2img = gr.Button('Send to img2img') + with gr.Row(): + with gr.Column(): + img_file_info = gr.Textbox(label="Generate Info", interactive=False) + img_file_name = gr.Textbox(value="", label="File Name", interactive=False) + img_file_time= gr.Textbox(value="", label="Create Time", interactive=False) - + + # hiden items + with gr.Row(): #visible=False): + visible_img_num = gr.Number() + date_to_recorder = gr.State([]) + last_date_from = gr.Textbox() + tabname_box = gr.Textbox(tabname) + image_index = gr.Textbox(value=-1) + set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") + filenames = gr.State() + all_images_list = gr.State() + hidden = gr.Image(type="pil") + info1 = gr.Textbox() + info2 = gr.Textbox() + + img_path.submit(change_dir, inputs=[img_path, date_to], outputs=[warning, main_panel, warning_box, date_to]) #change date - change_date_output = [date_to, date_from, filenames, page_index, history_gallery, img_file_name, visible_img_num] - newest.click(newest_click, inputs=[img_path, date_to], outputs=change_date_output) - date_to.change(archive_images, inputs=[img_path, date_to], outputs=change_date_output) - newest.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons]) - newest.click(enable_page_buttons, inputs=None, outputs=[turn_page_buttons]) + change_date_output = [date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from] + + date_to.change(archive_images, inputs=[img_path, date_to], outputs=change_date_output) + date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons]) + date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + + newest.click(newest_click, inputs=[date_to], outputs=[date_to, date_to_recorder]) + forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder]) + backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder]) + #delete delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num]) @@ -301,7 +369,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): #turn page gallery_inputs = [page_index, filenames] - gallery_outputs = [page_index, history_gallery, img_file_name, visible_img_num] + gallery_outputs = [page_index, history_gallery, img_file_name, img_file_time, visible_img_num] first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs) next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs) prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs) @@ -317,12 +385,14 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): renew_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") # other funcitons - set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, image_index, hidden]) + set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, img_file_time, image_index, hidden]) img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') + + def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict): global opts; opts = sys_opts @@ -330,10 +400,11 @@ def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict): num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num) with gr.Blocks(analytics_enabled=False) as images_history: with gr.Tabs() as tabs: - for tab in ["txt2img", "img2img", "extras", "saved"]: + for tab in [browser_tabname, "txt2img", "img2img", "extras", "saved"]: with gr.Tab(tab): - with gr.Blocks(analytics_enabled=False) as images_history_img2img: + with gr.Blocks(analytics_enabled=False) : show_images_history(gr, opts, tab, run_pnginfo, switch_dict) - gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory", visible=False) + #gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory", visible=False) + gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_finish_render", visible=False) return images_history diff --git a/modules/shared.py b/modules/shared.py index c2ea4186..1811018d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -309,10 +309,11 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) -options_templates.update(options_section(('images-history', "Images history"), { - "images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"), +options_templates.update(options_section(('images-history', "Images Browser"), { + #"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"), "images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"), - "images_history_pages_num": OptionInfo(6, "Maximum number of pages per load "), + "images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "), + "images_history_grid_num": OptionInfo(6, "Number of grids in each row"), })) diff --git a/modules/ui.py b/modules/ui.py index 43dc88fc..85abac4d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1548,7 +1548,7 @@ Requested path was: {f} (img2img_interface, "img2img", "img2img"), (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), - (images_history, "History", "images_history"), + (images_history, "Image Browser", "images_history"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "ti"), (settings_interface, "Settings", "settings"), diff --git a/uitest.bat b/uitest.bat new file mode 100644 index 00000000..ae863af6 --- /dev/null +++ b/uitest.bat @@ -0,0 +1,2 @@ +venv\Scripts\python.exe uitest.py +pause diff --git a/uitest.py b/uitest.py new file mode 100644 index 00000000..393e2d81 --- /dev/null +++ b/uitest.py @@ -0,0 +1,124 @@ +import os +import threading +import time +import importlib +import signal +import threading + +from modules.paths import script_path + +from modules import devices, sd_samplers +import modules.codeformer_model as codeformer +import modules.extras +import modules.face_restoration +import modules.gfpgan_model as gfpgan +import modules.img2img + +import modules.lowvram +import modules.paths +import modules.scripts +import modules.sd_hijack +import modules.sd_models +import modules.shared as shared +import modules.txt2img + +import modules.ui +from modules import devices +from modules import modelloader +from modules.paths import script_path +from modules.shared import cmd_opts + +modelloader.cleanup_models() +modules.sd_models.setup_model() +codeformer.setup_model(cmd_opts.codeformer_models_path) +gfpgan.setup_model(cmd_opts.gfpgan_models_path) +shared.face_restorers.append(modules.face_restoration.FaceRestoration()) +modelloader.load_upscalers() +queue_lock = threading.Lock() + + +def wrap_queued_call(func): + def f(*args, **kwargs): + with queue_lock: + res = func(*args, **kwargs) + + return res + + return f + + +def wrap_gradio_gpu_call(func, extra_outputs=None): + def f(*args, **kwargs): + devices.torch_gc() + + shared.state.sampling_step = 0 + shared.state.job_count = -1 + shared.state.job_no = 0 + shared.state.job_timestamp = shared.state.get_job_timestamp() + shared.state.current_latent = None + shared.state.current_image = None + shared.state.current_image_sampling_step = 0 + shared.state.interrupted = False + shared.state.textinfo = None + + with queue_lock: + res = func(*args, **kwargs) + + shared.state.job = "" + shared.state.job_count = 0 + + devices.torch_gc() + + return res + + return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) + + +modules.scripts.load_scripts(os.path.join(script_path, "scripts")) + +shared.sd_model = None #modules.sd_models.load_model() +#shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) + + +def webui(): + # make the program just exit at ctrl+c without waiting for anything + def sigint_handler(sig, frame): + print(f'Interrupted with signal {sig} in {frame}') + os._exit(0) + + signal.signal(signal.SIGINT, sigint_handler) + + while 1: + + demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) + + demo.launch( + share=cmd_opts.share, + server_name="0.0.0.0" if cmd_opts.listen else None, + server_port=cmd_opts.port, + debug=cmd_opts.gradio_debug, + auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, + inbrowser=cmd_opts.autolaunch, + prevent_thread_lock=True + ) + + while 1: + time.sleep(0.5) + if getattr(demo, 'do_restart', False): + time.sleep(0.5) + demo.close() + time.sleep(0.5) + break + + sd_samplers.set_samplers() + + print('Reloading Custom Scripts') + modules.scripts.reload_scripts(os.path.join(script_path, "scripts")) + print('Reloading modules: modules.ui') + importlib.reload(modules.ui) + print('Restarting Gradio') + + + +if __name__ == "__main__": + webui() \ No newline at end of file -- cgit v1.2.3 From 433a7525c1f5eb5963340e0cc45d31038ede3f7e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 18 Oct 2022 15:18:02 +0300 Subject: remove shared option for update check (because it is not an argument of webui) have launch.py examine both COMMANDLINE_ARGS as well as argv for its arguments --- launch.py | 19 +++++++++---------- modules/shared.py | 1 - 2 files changed, 9 insertions(+), 11 deletions(-) (limited to 'modules/shared.py') diff --git a/launch.py b/launch.py index 9d9dd1b6..7b15e78e 100644 --- a/launch.py +++ b/launch.py @@ -127,13 +127,14 @@ def prepare_enviroment(): codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") - args = shlex.split(commandline_args) + sys.argv += shlex.split(commandline_args) - args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test') - args, reinstall_xformers = extract_arg(args, '--reinstall-xformers') - xformers = '--xformers' in args - deepdanbooru = '--deepdanbooru' in args - ngrok = '--ngrok' in args + sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test') + sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers') + sys.argv, update_check = extract_arg(sys.argv, '--update-check') + xformers = '--xformers' in sys.argv + deepdanbooru = '--deepdanbooru' in sys.argv + ngrok = '--ngrok' in sys.argv try: commit = run(f"{git} rev-parse HEAD").strip() @@ -180,12 +181,10 @@ def prepare_enviroment(): run_pip(f"install -r {requirements_file}", "requirements for Web UI") - sys.argv += args - - if '--update-check' in args: + if update_check: version_check(commit) - if "--exit" in args: + if "--exit" in sys.argv: print("Exiting because of --exit argument") exit(0) diff --git a/modules/shared.py b/modules/shared.py index 50dc46ae..c0d87168 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -76,7 +76,6 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) -parser.add_argument("--update-check", action='store_true', help="enable http check to confirm that the currently running version is the most recent release.", default=False) cmd_opts = parser.parse_args() restricted_opts = [ -- cgit v1.2.3 From 6021f7a75f7b5208a2be15cda5526028152f922d Mon Sep 17 00:00:00 2001 From: discus0434 Date: Wed, 19 Oct 2022 00:51:36 +0900 Subject: add options to custom hypernetwork layer structure --- .gitignore | 1 + modules/hypernetworks/hypernetwork.py | 88 ++++++++++++++++++++++++++--------- modules/shared.py | 4 +- webui.py | 6 ++- 4 files changed, 75 insertions(+), 24 deletions(-) (limited to 'modules/shared.py') diff --git a/.gitignore b/.gitignore index 69785b3e..4794865c 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,4 @@ __pycache__ notification.mp3 /SwinIR /textual_inversion +/hypernetwork diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 4905710e..cadb9911 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -1,52 +1,98 @@ +import csv import datetime import glob import html import os import sys import traceback -import tqdm -import csv +import modules.textual_inversion.dataset import torch - -from ldm.util import default -from modules import devices, shared, processing, sd_models -import torch -from torch import einsum +import tqdm from einops import rearrange, repeat -import modules.textual_inversion.dataset +from ldm.util import default +from modules import devices, processing, sd_models, shared from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler +from torch import einsum + + +def parse_layer_structure(dim, state_dict): + i = 0 + res = [1] + while (key := "linear.{}.weight".format(i)) in state_dict: + weight = state_dict[key] + res.append(len(weight) // dim) + i += 1 + return res class HypernetworkModule(torch.nn.Module): multiplier = 1.0 + layer_structure = None + add_layer_norm = False def __init__(self, dim, state_dict=None): super().__init__() + if (state_dict is None or 'linear.0.weight' not in state_dict) and self.layer_structure is None: + layer_structure = (1, 2, 1) + else: + if self.layer_structure is not None: + assert self.layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" + assert self.layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" + layer_structure = self.layer_structure + else: + layer_structure = parse_layer_structure(dim, state_dict) + + linears = [] + for i in range(len(layer_structure) - 1): + linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) + if self.add_layer_norm: + linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) - self.linear1 = torch.nn.Linear(dim, dim * 2) - self.linear2 = torch.nn.Linear(dim * 2, dim) + self.linear = torch.nn.Sequential(*linears) if state_dict is not None: - self.load_state_dict(state_dict, strict=True) + try: + self.load_state_dict(state_dict) + except RuntimeError: + self.try_load_previous(state_dict) else: - - self.linear1.weight.data.normal_(mean=0.0, std=0.01) - self.linear1.bias.data.zero_() - self.linear2.weight.data.normal_(mean=0.0, std=0.01) - self.linear2.bias.data.zero_() + for layer in self.linear: + layer.weight.data.normal_(mean = 0.0, std = 0.01) + layer.bias.data.zero_() self.to(devices.device) + def try_load_previous(self, state_dict): + states = self.state_dict() + states['linear.0.bias'].copy_(state_dict['linear1.bias']) + states['linear.0.weight'].copy_(state_dict['linear1.weight']) + states['linear.1.bias'].copy_(state_dict['linear2.bias']) + states['linear.1.weight'].copy_(state_dict['linear2.weight']) + def forward(self, x): - return x + (self.linear2(self.linear1(x))) * self.multiplier + return x + self.linear(x) * self.multiplier + + def trainables(self): + res = [] + for layer in self.linear: + res += [layer.weight, layer.bias] + return res def apply_strength(value=None): HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength +def apply_layer_structure(value=None): + HypernetworkModule.layer_structure = value if value is not None else shared.opts.sd_hypernetwork_layer_structure + + +def apply_layer_norm(value=None): + HypernetworkModule.add_layer_norm = value if value is not None else shared.opts.sd_hypernetwork_add_layer_norm + + class Hypernetwork: filename = None name = None @@ -68,7 +114,7 @@ class Hypernetwork: for k, layers in self.layers.items(): for layer in layers: layer.train() - res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias] + res += layer.trainables() return res @@ -226,7 +272,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) - + assert ds.length > 1, "Dataset should contain more than 1 images" if unload: shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.first_stage_model.to(devices.cpu) @@ -261,7 +307,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log with torch.autocast("cuda"): c = stack_conds([entry.cond for entry in entries]).to(devices.device) -# c = torch.vstack([entry.cond for entry in entries]).to(devices.device) + c = torch.vstack([entry.cond for entry in entries]).to(devices.device) x = torch.stack([entry.latent for entry in entries]).to(devices.device) loss = shared.sd_model(x, c)[0] del x @@ -283,7 +329,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { "loss": f"{mean_loss:.7f}", - "learn_rate": scheduler.learn_rate + "learn_rate": f"{scheduler.learn_rate:.7f}" }) if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: diff --git a/modules/shared.py b/modules/shared.py index c0d87168..c87ce70e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, sd_models, localization +from modules import sd_models, sd_samplers, localization from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path @@ -258,6 +258,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), + "sd_hypernetwork_layer_structure": OptionInfo(None, "Hypernetwork layer structure Default: (1,2,1).", gr.Dropdown, lambda: {"choices": [(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)]}), + "sd_hypernetwork_add_layer_norm": OptionInfo(False, "Add layer normalization to hypernetwork architecture."), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), diff --git a/webui.py b/webui.py index fe0ce321..86e98ad0 100644 --- a/webui.py +++ b/webui.py @@ -86,11 +86,13 @@ def initialize(): shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) + shared.opts.onchange("sd_hypernetwork_layer_structure", modules.hypernetworks.hypernetwork.apply_layer_structure) + shared.opts.onchange("sd_hypernetwork_add_layer_norm", modules.hypernetworks.hypernetwork.apply_layer_norm) def webui(): initialize() - + # make the program just exit at ctrl+c without waiting for anything def sigint_handler(sig, frame): print(f'Interrupted with signal {sig} in {frame}') @@ -101,7 +103,7 @@ def webui(): while 1: demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) - + app, local_url, share_url = demo.launch( share=cmd_opts.share, server_name="0.0.0.0" if cmd_opts.listen else None, -- cgit v1.2.3 From 7f2095c6c8db82a5c9cd7c7177f6ba856a2cc676 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Wed, 19 Oct 2022 01:01:22 +0900 Subject: update --- modules/shared.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index c87ce70e..6b6d5c41 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,7 +13,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_models, sd_samplers, localization +from modules import sd_samplers, sd_models, localization from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path @@ -135,7 +135,7 @@ class State: self.job_no += 1 self.sampling_step = 0 self.current_image_sampling_step = 0 - + def get_job_timestamp(self): return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp? -- cgit v1.2.3 From 538bc89c269743e56b07ef2b471d1ce0a39b6776 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Wed, 19 Oct 2022 11:27:51 +0800 Subject: Image browser improved --- javascript/images_history.js | 87 ++++++++++++++-------------- modules/images_history.py | 135 ++++++++++++++++++++++++------------------- modules/shared.py | 5 ++ modules/ui.py | 2 +- 4 files changed, 123 insertions(+), 106 deletions(-) (limited to 'modules/shared.py') diff --git a/javascript/images_history.js b/javascript/images_history.js index 182d730b..c9aa76f8 100644 --- a/javascript/images_history.js +++ b/javascript/images_history.js @@ -17,14 +17,6 @@ var images_history_click_image = function(){ images_history_set_image_info(this); } -var images_history_click_tab = function(){ - var tabs_box = gradioApp().getElementById("images_history_tab"); - if (!tabs_box.classList.contains(this.getAttribute("tabname"))) { - gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_start").click(); - tabs_box.classList.add(this.getAttribute("tabname")) - } -} - function images_history_disabled_del(){ gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){ btn.setAttribute('disabled','disabled'); @@ -145,57 +137,64 @@ function images_history_enable_del_buttons(){ } function images_history_init(){ - // var loaded = gradioApp().getElementById("images_history_reconstruct_directory") - // if (loaded){ - // var init_status = loaded.querySelector("input").checked - if (gradioApp().getElementById("images_history_finish_render")){ + var tabnames = gradioApp().getElementById("images_history_tabnames_list") + if (tabnames){ + images_history_tab_list = tabnames.querySelector("textarea").value.split(",") for (var i in images_history_tab_list ){ - tab = images_history_tab_list[i]; + var tab = images_history_tab_list[i]; gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor"); gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index"); gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button"); - gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery"); - + gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery"); + gradioApp().getElementById(tab + "_images_history_start").setAttribute("style","padding:20px;font-size:25px"); + } + + //preload + if (gradioApp().getElementById("images_history_preload").querySelector("input").checked ){ + var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div"); + tabs_box.setAttribute("id", "images_history_tab"); + var tab_btns = tabs_box.querySelectorAll("button"); + for (var i in images_history_tab_list){ + var tabname = images_history_tab_list[i] + tab_btns[i].setAttribute("tabname", tabname); + tab_btns[i].addEventListener('click', function(){ + var tabs_box = gradioApp().getElementById("images_history_tab"); + if (!tabs_box.classList.contains(this.getAttribute("tabname"))) { + gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_start").click(); + tabs_box.classList.add(this.getAttribute("tabname")) + } + }); + } + tab_btns[0].click() } - var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div"); - tabs_box.setAttribute("id", "images_history_tab"); - var tab_btns = tabs_box.querySelectorAll("button"); - - for (var i in images_history_tab_list){ - var tabname = images_history_tab_list[i] - tab_btns[i].setAttribute("tabname", tabname); - // if (!init_status){ - // tab_btns[i].addEventListener('click', images_history_click_tab); - // } - tab_btns[i].addEventListener('click', images_history_click_tab); - } } else { setTimeout(images_history_init, 500); } } -var images_history_tab_list = ["custom", "txt2img", "img2img", "extras", "saved"]; +var images_history_tab_list = ""; setTimeout(images_history_init, 500); document.addEventListener("DOMContentLoaded", function() { var mutationObserver = new MutationObserver(function(m){ - for (var i in images_history_tab_list ){ - let tabname = images_history_tab_list[i] - var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item'); - buttons.forEach(function(bnt){ - bnt.addEventListener('click', images_history_click_image, true); - }); - - var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg"); - if (cls_btn){ - cls_btn.addEventListener('click', function(){ - gradioApp().getElementById(tabname + '_images_history_del_button').setAttribute('disabled','disabled'); - }, false); - } - - } + if (images_history_tab_list != ""){ + for (var i in images_history_tab_list ){ + let tabname = images_history_tab_list[i] + var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item'); + buttons.forEach(function(bnt){ + bnt.addEventListener('click', images_history_click_image, true); + }); + + var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg"); + if (cls_btn){ + cls_btn.addEventListener('click', function(){ + gradioApp().getElementById(tabname + '_images_history_renew_page').click(); + }, false); + } + + } + } }); mutationObserver.observe(gradioApp(), { childList:true, subtree:true }); - }); diff --git a/modules/images_history.py b/modules/images_history.py index a40cdc0e..78fd0543 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -4,7 +4,9 @@ import time import hashlib import gradio system_bak_path = "webui_log_and_bak" -browser_tabname = "custom" +custom_tab_name = "custom fold" +faverate_tab_name = "favorites" +tabs_list = ["txt2img", "img2img", "extras", faverate_tab_name] def is_valid_date(date): try: time.strptime(date, "%Y%m%d") @@ -122,7 +124,6 @@ def archive_images(dir_name, date_to): else: filenames = traverse_all_files(dir_name, filenames) total_num = len(filenames) - batch_count = len(filenames) + 1 // batch_size + 1 tmparray = [(os.path.getmtime(file), file) for file in filenames ] date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400 filenames = [] @@ -146,10 +147,12 @@ def archive_images(dir_name, date_to): last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000)) date = date[:4] + "/" + date[4:6] + "/" + date[6:8] date_to_bak = date_to_bak[:4] + "/" + date_to_bak[4:6] + "/" + date_to_bak[6:8] - load_info = f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages" + load_info = "
" + load_info += f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages" + load_info += "
" _, image_list, _, _, visible_num = get_recent_images(1, 0, filenames) return ( - gradio.Dropdown.update(choices=date_list, value=date_to), + date_to, load_info, filenames, 1, @@ -158,7 +161,7 @@ def archive_images(dir_name, date_to): "", visible_num, last_date_from, - #gradio.update(visible=batch_count > 1) + gradio.update(visible=total_num > num) ) def delete_image(delete_num, name, filenames, image_index, visible_num): @@ -209,7 +212,7 @@ def get_recent_images(page_index, step, filenames): visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num return page_index, image_list, "", "", visible_num -def newest_click(date_to): +def loac_batch_click(date_to): if date_to is None: return time.strftime("%Y%m%d",time.localtime(time.time())), [] else: @@ -248,7 +251,7 @@ def page_index_change(page_index, filenames): def show_image_info(tabname_box, num, page_index, filenames): file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))] - tm = time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + tm = "
" + time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + "
" return file, tm, num, file def enable_page_buttons(): @@ -268,9 +271,9 @@ def change_dir(img_dir, date_to): warning = "The format of the directory is incorrect" if warning is None: today = time.strftime("%Y%m%d",time.localtime(time.time())) - return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today + return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today, gradio.update(visible=True), gradio.update(visible=True) else: - return gradio.update(visible=True), gradio.update(visible=False), warning, date_to + return gradio.update(visible=True), gradio.update(visible=False), warning, date_to, gradio.update(visible=False), gradio.update(visible=False) def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): custom_dir = False @@ -280,7 +283,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): dir_name = opts.outdir_img2img_samples elif tabname == "extras": dir_name = opts.outdir_extras_samples - elif tabname == "saved": + elif tabname == faverate_tab_name: dir_name = opts.outdir_save else: custom_dir = True @@ -295,22 +298,26 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): os.makedirs(dir_name) with gr.Column() as page_panel: - with gr.Row(): - img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir) + with gr.Row(): + with gr.Column(scale=1, visible=not custom_dir) as load_batch_box: + load_batch = gr.Button('Load', elem_id=tabname + "_images_history_start", full_width=True) + with gr.Column(scale=4): + with gr.Row(): + img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir) + with gr.Row(): + with gr.Column(visible=False, scale=1) as batch_panel: + with gr.Row(): + forward = gr.Button('Prev batch') + backward = gr.Button('Next batch') + with gr.Column(scale=3): + load_info = gr.HTML(visible=not custom_dir) with gr.Row(visible=False) as warning: warning_box = gr.Textbox("Message", interactive=False) with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel: - with gr.Column(scale=2): - with gr.Row() as batch_panel: - forward = gr.Button('Forward') - date_to = gr.Dropdown(label="Date to") - backward = gr.Button('Backward') - newest = gr.Button('Reload', elem_id=tabname + "_images_history_start") - with gr.Row(): - load_info = gr.Textbox(show_label=False, interactive=False) - with gr.Row(visible=False) as turn_page_buttons: - renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page") + with gr.Column(scale=2): + with gr.Row(visible=True) as turn_page_buttons: + #date_to = gr.Dropdown(label="Date to") first_page = gr.Button('First Page') prev_page = gr.Button('Prev Page') page_index = gr.Number(value=1, label="Page Index") @@ -322,50 +329,54 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") - with gr.Column(): - with gr.Row(): - if tabname != "saved": - save_btn = gr.Button('Save') - pnginfo_send_to_txt2img = gr.Button('Send to txt2img') - pnginfo_send_to_img2img = gr.Button('Send to img2img') + with gr.Column(): with gr.Row(): with gr.Column(): - img_file_info = gr.Textbox(label="Generate Info", interactive=False) + img_file_info = gr.Textbox(label="Generate Info", interactive=False, lines=6) + gr.HTML("
") img_file_name = gr.Textbox(value="", label="File Name", interactive=False) - img_file_time= gr.Textbox(value="", label="Create Time", interactive=False) - + img_file_time= gr.HTML() + with gr.Row(): + if tabname != faverate_tab_name: + save_btn = gr.Button('Collect') + pnginfo_send_to_txt2img = gr.Button('Send to txt2img') + pnginfo_send_to_img2img = gr.Button('Send to img2img') + - # hiden items - with gr.Row(visible=False): - visible_img_num = gr.Number() - date_to_recorder = gr.State([]) - last_date_from = gr.Textbox() - tabname_box = gr.Textbox(tabname) - image_index = gr.Textbox(value=-1) - set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") - filenames = gr.State() - all_images_list = gr.State() - hidden = gr.Image(type="pil") - info1 = gr.Textbox() - info2 = gr.Textbox() - - img_path.submit(change_dir, inputs=[img_path, date_to], outputs=[warning, main_panel, warning_box, date_to]) - #change date - change_date_output = [date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from] + # hiden items + with gr.Row(visible=False): + renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page") + batch_date_to = gr.Textbox(label="Date to") + visible_img_num = gr.Number() + date_to_recorder = gr.State([]) + last_date_from = gr.Textbox() + tabname_box = gr.Textbox(tabname) + image_index = gr.Textbox(value=-1) + set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") + filenames = gr.State() + all_images_list = gr.State() + hidden = gr.Image(type="pil") + info1 = gr.Textbox() + info2 = gr.Textbox() + + img_path.submit(change_dir, inputs=[img_path, batch_date_to], outputs=[warning, main_panel, warning_box, batch_date_to, load_batch_box, load_info]) + + #change batch + change_date_output = [batch_date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from, batch_panel] - date_to.change(archive_images, inputs=[img_path, date_to], outputs=change_date_output) - date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons]) - date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") + batch_date_to.change(archive_images, inputs=[img_path, batch_date_to], outputs=change_date_output) + batch_date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons]) + batch_date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - newest.click(newest_click, inputs=[date_to], outputs=[date_to, date_to_recorder]) - forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder]) - backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[date_to, date_to_recorder]) + load_batch.click(loac_batch_click, inputs=[batch_date_to], outputs=[batch_date_to, date_to_recorder]) + forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder]) + backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder]) #delete delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num]) delete.click(fn=None, _js="images_history_delete", inputs=[delete_num, tabname_box, image_index], outputs=None) - if tabname != "saved": + if tabname != faverate_tab_name: save_btn.click(save_image, inputs=[img_file_name], outputs=None) #turn page @@ -394,18 +405,20 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): -def create_history_tabs(gr, sys_opts, run_pnginfo, switch_dict): +def create_history_tabs(gr, sys_opts, cmp_ops, run_pnginfo, switch_dict): global opts; opts = sys_opts loads_files_num = int(opts.images_history_num_per_page) num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num) + if cmp_ops.browse_all_images: + tabs_list.append(custom_tab_name) with gr.Blocks(analytics_enabled=False) as images_history: with gr.Tabs() as tabs: - for tab in [browser_tabname, "txt2img", "img2img", "extras", "saved"]: + for tab in tabs_list: with gr.Tab(tab): with gr.Blocks(analytics_enabled=False) : - show_images_history(gr, opts, tab, run_pnginfo, switch_dict) - #gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_reconstruct_directory", visible=False) - gradio.Checkbox(opts.images_history_reconstruct_directory, elem_id="images_history_finish_render", visible=False) - + show_images_history(gr, opts, tab, run_pnginfo, switch_dict) + gradio.Checkbox(opts.images_history_preload, elem_id="images_history_preload", visible=False) + gradio.Textbox(",".join(tabs_list), elem_id="images_history_tabnames_list", visible=False) + return images_history diff --git a/modules/shared.py b/modules/shared.py index 1811018d..4d735414 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -74,6 +74,10 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) +parser.add_argument("--browse-all-images", action='store_true', help="Allow browsing all images by Image Browser", default=False) + + +cmd_opts = parser.parse_args() cmd_opts = parser.parse_args() @@ -311,6 +315,7 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" options_templates.update(options_section(('images-history', "Images Browser"), { #"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"), + "images_history_preload": OptionInfo(False, "Preload images at startup"), "images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"), "images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "), "images_history_grid_num": OptionInfo(6, "Number of grids in each row"), diff --git a/modules/ui.py b/modules/ui.py index 85abac4d..88f46659 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1150,7 +1150,7 @@ def create_ui(wrap_gradio_gpu_call): "i2i":img2img_paste_fields } - images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) + images_history = img_his.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): -- cgit v1.2.3 From 42fbda83bb9830af18187fddb50c1bedd01da502 Mon Sep 17 00:00:00 2001 From: discus0434 Date: Wed, 19 Oct 2022 14:30:33 +0000 Subject: layer options moves into create hnet ui --- modules/hypernetworks/hypernetwork.py | 64 +++++++++++++++++------------------ modules/hypernetworks/ui.py | 9 +++-- modules/shared.py | 2 -- modules/ui.py | 8 +++-- webui.py | 8 ++--- 5 files changed, 48 insertions(+), 43 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 583ada31..7d519cd9 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -19,37 +19,21 @@ from modules.textual_inversion import textual_inversion from modules.textual_inversion.learn_schedule import LearnRateScheduler -def parse_layer_structure(dim, state_dict): - i = 0 - res = [1] - while (key := "linear.{}.weight".format(i)) in state_dict: - weight = state_dict[key] - res.append(len(weight) // dim) - i += 1 - return res - - class HypernetworkModule(torch.nn.Module): multiplier = 1.0 - layer_structure = None - add_layer_norm = False - def __init__(self, dim, state_dict=None): + def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False): super().__init__() - if (state_dict is None or 'linear.0.weight' not in state_dict) and self.layer_structure is None: - layer_structure = (1, 2, 1) + if layer_structure is not None: + assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" + assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" else: - if self.layer_structure is not None: - assert self.layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" - assert self.layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" - layer_structure = self.layer_structure - else: - layer_structure = parse_layer_structure(dim, state_dict) + layer_structure = parse_layer_structure(dim, state_dict) linears = [] for i in range(len(layer_structure) - 1): linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) - if self.add_layer_norm: + if add_layer_norm: linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) self.linear = torch.nn.Sequential(*linears) @@ -77,38 +61,47 @@ class HypernetworkModule(torch.nn.Module): return x + self.linear(x) * self.multiplier def trainables(self): - res = [] + layer_structure = [] for layer in self.linear: - res += [layer.weight, layer.bias] - return res + layer_structure += [layer.weight, layer.bias] + return layer_structure def apply_strength(value=None): HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength -def apply_layer_structure(value=None): - HypernetworkModule.layer_structure = value if value is not None else shared.opts.sd_hypernetwork_layer_structure +def parse_layer_structure(dim, state_dict): + i = 0 + layer_structure = [1] + while (key := "linear.{}.weight".format(i)) in state_dict: + weight = state_dict[key] + layer_structure.append(len(weight) // dim) + i += 1 -def apply_layer_norm(value=None): - HypernetworkModule.add_layer_norm = value if value is not None else shared.opts.sd_hypernetwork_add_layer_norm + return layer_structure class Hypernetwork: filename = None name = None - def __init__(self, name=None, enable_sizes=None): + def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False): self.filename = None self.name = name self.layers = {} self.step = 0 self.sd_checkpoint = None self.sd_checkpoint_name = None + self.layer_structure = layer_structure + self.add_layer_norm = add_layer_norm for size in enable_sizes or []: - self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size)) + self.layers[size] = ( + HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm), + HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm), + ) def weights(self): res = [] @@ -128,6 +121,8 @@ class Hypernetwork: state_dict['step'] = self.step state_dict['name'] = self.name + state_dict['layer_structure'] = self.layer_structure + state_dict['is_layer_norm'] = self.add_layer_norm state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name @@ -142,10 +137,15 @@ class Hypernetwork: for size, sd in state_dict.items(): if type(size) == int: - self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1])) + self.layers[size] = ( + HypernetworkModule(size, sd[0], state_dict["layer_structure"], state_dict["is_layer_norm"]), + HypernetworkModule(size, sd[1], state_dict["layer_structure"], state_dict["is_layer_norm"]), + ) self.name = state_dict.get('name', self.name) self.step = state_dict.get('step', 0) + self.layer_structure = state_dict.get('layer_structure', None) + self.add_layer_norm = state_dict.get('is_layer_norm', False) self.sd_checkpoint = state_dict.get('sd_checkpoint', None) self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None) diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index dfa599af..7e8ea95e 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -9,11 +9,16 @@ from modules import sd_hijack, shared, devices from modules.hypernetworks import hypernetwork -def create_hypernetwork(name, enable_sizes): +def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False): fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") assert not os.path.exists(fn), f"file {fn} already exists" - hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes]) + hypernet = modules.hypernetworks.hypernetwork.Hypernetwork( + name=name, + enable_sizes=[int(x) for x in enable_sizes], + layer_structure=layer_structure, + add_layer_norm=add_layer_norm, + ) hypernet.save(fn) shared.reload_hypernetworks() diff --git a/modules/shared.py b/modules/shared.py index 0540cae9..faede821 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -260,8 +260,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), - "sd_hypernetwork_layer_structure": OptionInfo(None, "Hypernetwork layer structure Default: (1,2,1).", gr.Dropdown, lambda: {"choices": [(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)]}), - "sd_hypernetwork_add_layer_norm": OptionInfo(False, "Add layer normalization to hypernetwork architecture."), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), diff --git a/modules/ui.py b/modules/ui.py index ca46343f..d9ee462f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -458,14 +458,14 @@ def create_toprow(is_img2img): with gr.Row(): with gr.Column(scale=80): with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, + prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=2, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)" ) with gr.Row(): with gr.Column(scale=80): with gr.Row(): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, + negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)" ) @@ -1198,6 +1198,8 @@ def create_ui(wrap_gradio_gpu_call): with gr.Tab(label="Create hypernetwork"): new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) + new_hypernetwork_layer_structure = gr.Dropdown(label="Hypernetwork layer structure", choices=[(1, 2, 1), (1, 2, 2, 1), (1, 2, 4, 2, 1)]) + new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") with gr.Row(): with gr.Column(scale=3): @@ -1280,6 +1282,8 @@ def create_ui(wrap_gradio_gpu_call): inputs=[ new_hypernetwork_name, new_hypernetwork_sizes, + new_hypernetwork_layer_structure, + new_hypernetwork_add_layer_norm, ], outputs=[ train_hypernetwork_name, diff --git a/webui.py b/webui.py index c7260c7a..177bef74 100644 --- a/webui.py +++ b/webui.py @@ -85,9 +85,7 @@ def initialize(): shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) - shared.opts.onchange("sd_hypernetwork_layer_structure", modules.hypernetworks.hypernetwork.apply_layer_structure) - shared.opts.onchange("sd_hypernetwork_add_layer_norm", modules.hypernetworks.hypernetwork.apply_layer_norm) - + # make the program just exit at ctrl+c without waiting for anything def sigint_handler(sig, frame): print(f'Interrupted with signal {sig} in {frame}') @@ -142,7 +140,7 @@ def webui(launch_api=False): create_api(app) wait_on_server(demo) - + sd_samplers.set_samplers() print('Reloading Custom Scripts') @@ -160,4 +158,4 @@ if __name__ == "__main__": if cmd_opts.nowebui: api_only() else: - webui(cmd_opts.api) \ No newline at end of file + webui(cmd_opts.api) -- cgit v1.2.3 From d07cb46f34b3d9fe7a78b102f899ebef352ea56b Mon Sep 17 00:00:00 2001 From: yfszzx Date: Thu, 20 Oct 2022 23:58:52 +0800 Subject: inspiration pull request --- .gitignore | 3 +- javascript/imageviewer.js | 1 - javascript/inspiration.js | 42 ++++++++++++ modules/inspiration.py | 122 +++++++++++++++++++++++++++++++++++ modules/shared.py | 1 + modules/ui.py | 13 ++-- scripts/create_inspiration_images.py | 45 +++++++++++++ webui.py | 5 ++ 8 files changed, 225 insertions(+), 7 deletions(-) create mode 100644 javascript/inspiration.js create mode 100644 modules/inspiration.py create mode 100644 scripts/create_inspiration_images.py (limited to 'modules/shared.py') diff --git a/.gitignore b/.gitignore index f9c3357c..434d50b7 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,5 @@ __pycache__ notification.mp3 /SwinIR /textual_inversion -.vscode \ No newline at end of file +.vscode +/inspiration \ No newline at end of file diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 9e380c65..d4ab6984 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -116,7 +116,6 @@ function showGalleryImage() { e.dataset.modded = true; if(e && e.parentElement.tagName == 'DIV'){ e.style.cursor='pointer' - e.style.userSelect='none' e.addEventListener('click', function (evt) { if(!opts.js_modal_lightbox) return; modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed) diff --git a/javascript/inspiration.js b/javascript/inspiration.js new file mode 100644 index 00000000..e1c0e114 --- /dev/null +++ b/javascript/inspiration.js @@ -0,0 +1,42 @@ +function public_image_index_in_gallery(item, gallery){ + var index; + var i = 0; + gallery.querySelectorAll("img").forEach(function(e){ + if (e == item) + index = i; + i += 1; + }); + return index; +} + +function inspiration_selected(name, types, name_list){ + var btn = gradioApp().getElementById("inspiration_select_button") + return [gradioApp().getElementById("inspiration_select_button").getAttribute("img-index"), types]; +} +var inspiration_image_click = function(){ + var index = public_image_index_in_gallery(this, gradioApp().getElementById("inspiration_gallery")); + var btn = gradioApp().getElementById("inspiration_select_button") + btn.setAttribute("img-index", index) + setTimeout(function(btn){btn.click();}, 10, btn) +} + +document.addEventListener("DOMContentLoaded", function() { + var mutationObserver = new MutationObserver(function(m){ + var gallery = gradioApp().getElementById("inspiration_gallery") + if (gallery) { + var node = gallery.querySelector(".absolute.backdrop-blur.h-full") + if (node) { + node.style.display = "None"; //parentNode.removeChild(node) + } + + gallery.querySelectorAll('img').forEach(function(e){ + e.onclick = inspiration_image_click + }) + + } + + + }); + mutationObserver.observe( gradioApp(), { childList:true, subtree:true }); + +}); diff --git a/modules/inspiration.py b/modules/inspiration.py new file mode 100644 index 00000000..456bfcb5 --- /dev/null +++ b/modules/inspiration.py @@ -0,0 +1,122 @@ +import os +import random +import gradio +inspiration_path = "inspiration" +inspiration_system_path = os.path.join(inspiration_path, "system") +def read_name_list(file): + if not os.path.exists(file): + return [] + f = open(file, "r") + ret = [] + line = f.readline() + while len(line) > 0: + line = line.rstrip("\n") + ret.append(line) + print(ret) + return ret + +def save_name_list(file, name): + print(file) + f = open(file, "a") + f.write(name + "\n") + +def get_inspiration_images(source, types): + path = os.path.join(inspiration_path , types) + if source == "Favorites": + names = read_name_list(os.path.join(inspiration_system_path, types + "_faverites.txt")) + names = random.sample(names, 25) + elif source == "Abandoned": + names = read_name_list(os.path.join(inspiration_system_path, types + "_abondened.txt")) + names = random.sample(names, 25) + elif source == "Exclude abandoned": + abondened = read_name_list(os.path.join(inspiration_system_path, types + "_abondened.txt")) + all_names = os.listdir(path) + names = [] + while len(names) < 25: + name = random.choice(all_names) + if name not in abondened: + names.append(name) + else: + names = random.sample(os.listdir(path), 25) + names = random.sample(names, 25) + image_list = [] + for a in names: + image_path = os.path.join(path, a) + images = os.listdir(image_path) + image_list.append(os.path.join(image_path, random.choice(images))) + return image_list, names + +def select_click(index, types, name_list): + name = name_list[int(index)] + path = os.path.join(inspiration_path, types, name) + images = os.listdir(path) + return name, [os.path.join(path, x) for x in images] + +def give_up_click(name, types): + file = os.path.join(inspiration_system_path, types + "_abandoned.txt") + name_list = read_name_list(file) + if name not in name_list: + save_name_list(file, name) + +def collect_click(name, types): + file = os.path.join(inspiration_system_path, types + "_faverites.txt") + print(file) + name_list = read_name_list(file) + print(name_list) + if name not in name_list: + save_name_list(file, name) + +def moveout_click(name, types): + file = os.path.join(inspiration_system_path, types + "_faverites.txt") + name_list = read_name_list(file) + if name not in name_list: + save_name_list(file, name) + +def source_change(source): + if source == "Abandoned" or source == "Favorites": + return gradio.Button.update(visible=True, value=f"Move out {source}") + else: + return gradio.Button.update(visible=False) + +def ui(gr, opts): + with gr.Blocks(analytics_enabled=False) as inspiration: + flag = os.path.exists(inspiration_path) + if flag: + types = os.listdir(inspiration_path) + types = [x for x in types if x != "system"] + flag = len(types) > 0 + if not flag: + os.mkdir(inspiration_path) + gr.HTML(""" +
" + """) + return inspiration + if not os.path.exists(inspiration_system_path): + os.mkdir(inspiration_system_path) + gallery, names = get_inspiration_images("Exclude abandoned", types[0]) + with gr.Row(): + with gr.Column(scale=2): + inspiration_gallery = gr.Gallery(gallery, show_label=False, elem_id="inspiration_gallery").style(grid=5, height='auto') + with gr.Column(scale=1): + types = gr.Dropdown(choices=types, value=types[0], label="Type", visible=len(types) > 1) + with gr.Row(): + source = gr.Dropdown(choices=["All", "Favorites", "Exclude abandoned", "Abandoned"], value="Exclude abandoned", label="Source") + get_inspiration = gr.Button("Get inspiration") + name = gr.Textbox(show_label=False, interactive=False) + with gr.Row(): + send_to_txt2img = gr.Button('to txt2img') + send_to_img2img = gr.Button('to img2img') + style_gallery = gr.Gallery(show_label=False, elem_id="inspiration_style_gallery").style(grid=2, height='auto') + + collect = gr.Button('Collect') + give_up = gr.Button("Don't show any more") + moveout = gr.Button("Move out", visible=False) + with gr.Row(): + select_button = gr.Button('set button', elem_id="inspiration_select_button") + name_list = gr.State(names) + source.change(source_change, inputs=[source], outputs=[moveout]) + get_inspiration.click(get_inspiration_images, inputs=[source, types], outputs=[inspiration_gallery, name_list]) + select_button.click(select_click, _js="inspiration_selected", inputs=[name, types, name_list], outputs=[name, style_gallery]) + give_up.click(give_up_click, inputs=[name, types], outputs=None) + collect.click(collect_click, inputs=[name, types], outputs=None) + return inspiration diff --git a/modules/shared.py b/modules/shared.py index faede821..ae033710 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -78,6 +78,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") +parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") cmd_opts = parser.parse_args() restricted_opts = [ diff --git a/modules/ui.py b/modules/ui.py index a2dbd41e..6a0a3c3b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -41,7 +41,8 @@ from modules import prompt_parser from modules.images import save_image import modules.textual_inversion.ui import modules.hypernetworks.ui -import modules.images_history as img_his +import modules.images_history as images_history +import modules.inspiration as inspiration # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() @@ -1082,9 +1083,9 @@ def create_ui(wrap_gradio_gpu_call): upscaling_resize_w = gr.Number(label="Width", value=512, precision=0) upscaling_resize_h = gr.Number(label="Height", value=512, precision=0) upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) - + with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers] , value=shared.sd_upscalers[0].name, type="index") with gr.Group(): extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") @@ -1178,7 +1179,8 @@ def create_ui(wrap_gradio_gpu_call): "i2i":img2img_paste_fields } - images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) + browser_interface = images_history.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) + inspiration_interface = inspiration.ui(gr, opts) with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): @@ -1595,7 +1597,8 @@ Requested path was: {f} (img2img_interface, "img2img", "img2img"), (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), - (images_history, "History", "images_history"), + (browser_interface, "History", "images_history"), + (inspiration_interface, "Inspiration", "inspiration"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "ti"), (settings_interface, "Settings", "settings"), diff --git a/scripts/create_inspiration_images.py b/scripts/create_inspiration_images.py new file mode 100644 index 00000000..6a20def8 --- /dev/null +++ b/scripts/create_inspiration_images.py @@ -0,0 +1,45 @@ +import csv, os, shutil +import modules.scripts as scripts +from modules import processing, shared, sd_samplers, images +from modules.processing import Processed + + +class Script(scripts.Script): + def title(self): + return "Create artists style image" + + def show(self, is_img2img): + return not is_img2img + + def ui(self, is_img2img): + return [] + def show(self, is_img2img): + return not is_img2img + + def run(self, p): #, max_snapshoots_num): + path = os.path.join("style_snapshoot", "artist") + if not os.path.exists(path): + os.makedirs(path) + p.do_not_save_samples = True + p.do_not_save_grid = True + p.negative_prompt = "portrait photo" + f = open('artists.csv') + f_csv = csv.reader(f) + for row in f_csv: + name = row[0] + artist_path = os.path.join(path, name) + if not os.path.exists(artist_path): + os.mkdir(artist_path) + if len(os.listdir(artist_path)) > 0: + continue + print(name) + p.prompt = name + processed = processing.process_images(p) + for img in processed.images: + i = 0 + filename = os.path.join(artist_path, format(0, "03d") + ".jpg") + while os.path.exists(filename): + i += 1 + filename = os.path.join(artist_path, format(i, "03d") + ".jpg") + img.save(filename, quality=70) + return processed diff --git a/webui.py b/webui.py index 177bef74..5923905f 100644 --- a/webui.py +++ b/webui.py @@ -72,6 +72,11 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) def initialize(): + if cmd_opts.ui_debug_mode: + class enmpty(): + name = None + shared.sd_upscalers = [enmpty()] + return modelloader.cleanup_models() modules.sd_models.setup_model() codeformer.setup_model(cmd_opts.codeformer_models_path) -- cgit v1.2.3 From a3b047b7c74dc6ca07f40aee778997fc1889d72f Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Thu, 20 Oct 2022 19:28:58 -0500 Subject: add settings option to toggle button visibility --- javascript/ui.js | 1 - modules/shared.py | 1 + modules/ui.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/shared.py') diff --git a/javascript/ui.js b/javascript/ui.js index 39eae1f7..f19af550 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -163,7 +163,6 @@ function selected_tab_id() { } function trash_prompt(_,_, is_img2img) { -//txt2img_token_button if(!confirm("Delete prompt?")) return false diff --git a/modules/shared.py b/modules/shared.py index faede821..7e9c2696 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -300,6 +300,7 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), + "trash_prompt_visible": OptionInfo(True, "Show trash prompt button"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) diff --git a/modules/ui.py b/modules/ui.py index bde546cc..13c0b4ca 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -509,7 +509,7 @@ def create_toprow(is_img2img): paste = gr.Button(value=paste_symbol, elem_id="paste") save_style = gr.Button(value=save_style_symbol, elem_id="style_create") prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - trash_prompt = gr.Button(value=trash_prompt_symbol, elem_id="trash_prompt") + trash_prompt = gr.Button(value=trash_prompt_symbol, elem_id="trash_prompt", visible=opts.trash_prompt_visible) token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") -- cgit v1.2.3 From 02e4d4694dd9254a6ca9f05c2eb7b01ea508abc7 Mon Sep 17 00:00:00 2001 From: Rcmcpe Date: Fri, 21 Oct 2022 15:53:35 +0800 Subject: Change option description of unload_models_when_training --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 5c675b80..41d7f08e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -266,7 +266,7 @@ options_templates.update(options_section(('system', "System"), { })) options_templates.update(options_section(('training', "Training"), { - "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"), + "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), -- cgit v1.2.3 From bb0f1a2cdae3410a41d06ae878f56e29b8154c41 Mon Sep 17 00:00:00 2001 From: yfszzx Date: Sat, 22 Oct 2022 01:23:00 +0800 Subject: inspiration finished --- javascript/inspiration.js | 27 ++++--- modules/inspiration.py | 192 ++++++++++++++++++++++++++++++---------------- modules/shared.py | 6 ++ modules/ui.py | 2 +- webui.py | 3 +- 5 files changed, 151 insertions(+), 79 deletions(-) (limited to 'modules/shared.py') diff --git a/javascript/inspiration.js b/javascript/inspiration.js index e1c0e114..791a80c9 100644 --- a/javascript/inspiration.js +++ b/javascript/inspiration.js @@ -1,25 +1,31 @@ function public_image_index_in_gallery(item, gallery){ + var imgs = gallery.querySelectorAll("img.h-full") var index; var i = 0; - gallery.querySelectorAll("img").forEach(function(e){ + imgs.forEach(function(e){ if (e == item) index = i; i += 1; }); + var num = imgs.length / 2 + index = (index < num) ? index : (index - num) return index; } -function inspiration_selected(name, types, name_list){ +function inspiration_selected(name, name_list){ var btn = gradioApp().getElementById("inspiration_select_button") - return [gradioApp().getElementById("inspiration_select_button").getAttribute("img-index"), types]; -} + return [gradioApp().getElementById("inspiration_select_button").getAttribute("img-index")]; +} +function inspiration_click_get_button(){ + gradioApp().getElementById("inspiration_get_button").click(); +} var inspiration_image_click = function(){ var index = public_image_index_in_gallery(this, gradioApp().getElementById("inspiration_gallery")); - var btn = gradioApp().getElementById("inspiration_select_button") - btn.setAttribute("img-index", index) - setTimeout(function(btn){btn.click();}, 10, btn) + var btn = gradioApp().getElementById("inspiration_select_button"); + btn.setAttribute("img-index", index); + setTimeout(function(btn){btn.click();}, 10, btn); } - + document.addEventListener("DOMContentLoaded", function() { var mutationObserver = new MutationObserver(function(m){ var gallery = gradioApp().getElementById("inspiration_gallery") @@ -27,11 +33,10 @@ document.addEventListener("DOMContentLoaded", function() { var node = gallery.querySelector(".absolute.backdrop-blur.h-full") if (node) { node.style.display = "None"; //parentNode.removeChild(node) - } - + } gallery.querySelectorAll('img').forEach(function(e){ e.onclick = inspiration_image_click - }) + }); } diff --git a/modules/inspiration.py b/modules/inspiration.py index 456bfcb5..f72ebf3a 100644 --- a/modules/inspiration.py +++ b/modules/inspiration.py @@ -1,122 +1,182 @@ import os import random -import gradio -inspiration_path = "inspiration" -inspiration_system_path = os.path.join(inspiration_path, "system") -def read_name_list(file): +import gradio +from modules.shared import opts +inspiration_system_path = os.path.join(opts.inspiration_dir, "system") +def read_name_list(file, types=None, keyword=None): if not os.path.exists(file): return [] - f = open(file, "r") ret = [] + f = open(file, "r") line = f.readline() while len(line) > 0: line = line.rstrip("\n") - ret.append(line) - print(ret) + if types is not None: + dirname = os.path.split(line) + if dirname[0] in types and keyword in dirname[1]: + ret.append(line) + else: + ret.append(line) + line = f.readline() return ret def save_name_list(file, name): - print(file) - f = open(file, "a") - f.write(name + "\n") + with open(file, "a") as f: + f.write(name + "\n") -def get_inspiration_images(source, types): - path = os.path.join(inspiration_path , types) +def get_types_list(): + files = os.listdir(opts.inspiration_dir) + types = [] + for x in files: + path = os.path.join(opts.inspiration_dir, x) + if x[0] == ".": + continue + if not os.path.isdir(path): + continue + if path == inspiration_system_path: + continue + types.append(x) + return types + +def get_inspiration_images(source, types, keyword): + get_num = int(opts.inspiration_rows_num * opts.inspiration_cols_num) if source == "Favorites": - names = read_name_list(os.path.join(inspiration_system_path, types + "_faverites.txt")) - names = random.sample(names, 25) + names = read_name_list(os.path.join(inspiration_system_path, "faverites.txt"), types, keyword) + names = random.sample(names, get_num) if len(names) > get_num else names elif source == "Abandoned": - names = read_name_list(os.path.join(inspiration_system_path, types + "_abondened.txt")) - names = random.sample(names, 25) - elif source == "Exclude abandoned": - abondened = read_name_list(os.path.join(inspiration_system_path, types + "_abondened.txt")) - all_names = os.listdir(path) - names = [] - while len(names) < 25: - name = random.choice(all_names) - if name not in abondened: - names.append(name) + names = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword) + print(names) + names = random.sample(names, get_num) if len(names) > get_num else names + elif source == "Exclude abandoned": + abandoned = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword) + all_names = [] + for tp in types: + name_list = os.listdir(os.path.join(opts.inspiration_dir, tp)) + all_names += [os.path.join(tp, x) for x in name_list if keyword in x] + + if len(all_names) > get_num: + names = [] + while len(names) < get_num: + name = random.choice(all_names) + if name not in abandoned: + names.append(name) + else: + names = all_names else: - names = random.sample(os.listdir(path), 25) - names = random.sample(names, 25) + all_names = [] + for tp in types: + name_list = os.listdir(os.path.join(opts.inspiration_dir, tp)) + all_names += [os.path.join(tp, x) for x in name_list if keyword in x] + names = random.sample(all_names, get_num) if len(all_names) > get_num else all_names image_list = [] for a in names: - image_path = os.path.join(path, a) + image_path = os.path.join(opts.inspiration_dir, a) images = os.listdir(image_path) - image_list.append(os.path.join(image_path, random.choice(images))) - return image_list, names + image_list.append((os.path.join(image_path, random.choice(images)), a)) + return image_list, names, "" -def select_click(index, types, name_list): +def select_click(index, name_list): name = name_list[int(index)] - path = os.path.join(inspiration_path, types, name) + path = os.path.join(opts.inspiration_dir, name) images = os.listdir(path) - return name, [os.path.join(path, x) for x in images] + return name, [os.path.join(path, x) for x in images], "" -def give_up_click(name, types): - file = os.path.join(inspiration_system_path, types + "_abandoned.txt") +def give_up_click(name): + file = os.path.join(inspiration_system_path, "abandoned.txt") name_list = read_name_list(file) if name not in name_list: save_name_list(file, name) + return "Added to abandoned list" -def collect_click(name, types): - file = os.path.join(inspiration_system_path, types + "_faverites.txt") - print(file) +def collect_click(name): + file = os.path.join(inspiration_system_path, "faverites.txt") name_list = read_name_list(file) - print(name_list) if name not in name_list: save_name_list(file, name) + return "Added to faverite list" -def moveout_click(name, types): - file = os.path.join(inspiration_system_path, types + "_faverites.txt") +def moveout_click(name, source): + if source == "Abandoned": + file = os.path.join(inspiration_system_path, "abandoned.txt") + if source == "Favorites": + file = os.path.join(inspiration_system_path, "faverites.txt") + else: + return None name_list = read_name_list(file) - if name not in name_list: - save_name_list(file, name) + os.remove(file) + with open(file, "a") as f: + for a in name_list: + if a != name: + f.write(a) + return "Moved out {name} from {source} list" def source_change(source): - if source == "Abandoned" or source == "Favorites": - return gradio.Button.update(visible=True, value=f"Move out {source}") + if source in ["Abandoned", "Favorites"]: + return gradio.update(visible=True), [] else: - return gradio.Button.update(visible=False) + return gradio.update(visible=False), [] +def add_to_prompt(name, prompt): + print(name, prompt) + name = os.path.basename(name) + return prompt + "," + name -def ui(gr, opts): +def ui(gr, opts, txt2img_prompt, img2img_prompt): with gr.Blocks(analytics_enabled=False) as inspiration: - flag = os.path.exists(inspiration_path) + flag = os.path.exists(opts.inspiration_dir) if flag: - types = os.listdir(inspiration_path) - types = [x for x in types if x != "system"] + types = get_types_list() flag = len(types) > 0 - if not flag: - os.mkdir(inspiration_path) + else: + os.makedirs(opts.inspiration_dir) + if not flag: gr.HTML(""" -
" +

To activate inspiration function, you need get "inspiration" images first.


+ You can create these images by run "Create inspiration images" script in txt2img page,
you can get the artists or art styles list from here
+ https://github.com/pharmapsychotic/clip-interrogator/tree/main/data
+ download these files, and select these files in the "Create inspiration images" script UI
+ There about 6000 artists and art styles in these files.
This takes server hours depending on your GPU type and how many pictures you generate for each artist/style +
I suggest at least four images for each


+

You can also download generated pictures from here:


+ https://huggingface.co/datasets/yfszzx/inspiration
+ unzip the file to the project directory of webui
+ and restart webui, and enjoy the joy of creation!
""") return inspiration if not os.path.exists(inspiration_system_path): os.mkdir(inspiration_system_path) - gallery, names = get_inspiration_images("Exclude abandoned", types[0]) with gr.Row(): with gr.Column(scale=2): - inspiration_gallery = gr.Gallery(gallery, show_label=False, elem_id="inspiration_gallery").style(grid=5, height='auto') + inspiration_gallery = gr.Gallery(show_label=False, elem_id="inspiration_gallery").style(grid=opts.inspiration_cols_num, height='auto') with gr.Column(scale=1): - types = gr.Dropdown(choices=types, value=types[0], label="Type", visible=len(types) > 1) + print(types) + types = gr.CheckboxGroup(choices=types, value=types) + keyword = gr.Textbox("", label="Key word") with gr.Row(): source = gr.Dropdown(choices=["All", "Favorites", "Exclude abandoned", "Abandoned"], value="Exclude abandoned", label="Source") - get_inspiration = gr.Button("Get inspiration") + get_inspiration = gr.Button("Get inspiration", elem_id="inspiration_get_button") name = gr.Textbox(show_label=False, interactive=False) with gr.Row(): send_to_txt2img = gr.Button('to txt2img') send_to_img2img = gr.Button('to img2img') - style_gallery = gr.Gallery(show_label=False, elem_id="inspiration_style_gallery").style(grid=2, height='auto') - + style_gallery = gr.Gallery(show_label=False).style(grid=2, height='auto') collect = gr.Button('Collect') - give_up = gr.Button("Don't show any more") + give_up = gr.Button("Don't show again") moveout = gr.Button("Move out", visible=False) - with gr.Row(): + warning = gr.HTML() + with gr.Row(visible=False): select_button = gr.Button('set button', elem_id="inspiration_select_button") - name_list = gr.State(names) - source.change(source_change, inputs=[source], outputs=[moveout]) - get_inspiration.click(get_inspiration_images, inputs=[source, types], outputs=[inspiration_gallery, name_list]) - select_button.click(select_click, _js="inspiration_selected", inputs=[name, types, name_list], outputs=[name, style_gallery]) - give_up.click(give_up_click, inputs=[name, types], outputs=None) - collect.click(collect_click, inputs=[name, types], outputs=None) + name_list = gr.State() + + get_inspiration.click(get_inspiration_images, inputs=[source, types, keyword], outputs=[inspiration_gallery, name_list, keyword]) + source.change(source_change, inputs=[source], outputs=[moveout, style_gallery]) + source.change(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) + keyword.submit(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) + select_button.click(select_click, _js="inspiration_selected", inputs=[name, name_list], outputs=[name, style_gallery, warning]) + give_up.click(give_up_click, inputs=[name], outputs=[warning]) + collect.click(collect_click, inputs=[name], outputs=[warning]) + moveout.click(moveout_click, inputs=[name, source], outputs=[warning]) + send_to_txt2img.click(add_to_prompt, inputs=[name, txt2img_prompt], outputs=[txt2img_prompt]) + send_to_img2img.click(add_to_prompt, inputs=[name, img2img_prompt], outputs=[img2img_prompt]) + send_to_txt2img.click(None, _js='switch_to_txt2img', inputs=None, outputs=None) + send_to_img2img.click(None, _js="switch_to_img2img_img2img", inputs=None, outputs=None) return inspiration diff --git a/modules/shared.py b/modules/shared.py index ae033710..564b1b8d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -316,6 +316,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) +options_templates.update(options_section(('inspiration', "Inspiration"), { + "inspiration_dir": OptionInfo("inspiration", "Directory of inspiration", component_args=hide_dirs), + "inspiration_max_samples": OptionInfo(4, "Maximum number of samples, used to determine which folders to skip when continue running the create script", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}), + "inspiration_rows_num": OptionInfo(4, "Rows of inspiration interface frame", gr.Slider, {"minimum": 4, "maximum": 16, "step": 1}), + "inspiration_cols_num": OptionInfo(8, "Columns of inspiration interface frame", gr.Slider, {"minimum": 4, "maximum": 16, "step": 1}), +})) class Options: data = None diff --git a/modules/ui.py b/modules/ui.py index 6a0a3c3b..b651eb9c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1180,7 +1180,7 @@ def create_ui(wrap_gradio_gpu_call): } browser_interface = images_history.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) - inspiration_interface = inspiration.ui(gr, opts) + inspiration_interface = inspiration.ui(gr, opts, txt2img_prompt, img2img_prompt) with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): diff --git a/webui.py b/webui.py index 5923905f..5ccae715 100644 --- a/webui.py +++ b/webui.py @@ -72,6 +72,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) def initialize(): + modules.scripts.load_scripts(os.path.join(script_path, "scripts")) if cmd_opts.ui_debug_mode: class enmpty(): name = None @@ -84,7 +85,7 @@ def initialize(): shared.face_restorers.append(modules.face_restoration.FaceRestoration()) modelloader.load_upscalers() - modules.scripts.load_scripts(os.path.join(script_path, "scripts")) + shared.sd_model = modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) -- cgit v1.2.3 From 9e40520f00d836cfa93187f7f1e81e2a7bd100b9 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Fri, 21 Oct 2022 15:13:12 -0500 Subject: refactor internal terminology to use 'clear' instead of 'trash' like #2728 --- javascript/ui.js | 2 +- modules/shared.py | 2 +- modules/ui.py | 22 +++++++++++----------- 3 files changed, 13 insertions(+), 13 deletions(-) (limited to 'modules/shared.py') diff --git a/javascript/ui.js b/javascript/ui.js index acd57565..45d93a5c 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -162,7 +162,7 @@ function selected_tab_id() { } -function trash_prompt(_, confirmed,_steps) { +function clear_prompt(_, confirmed,_steps) { if(confirm("Delete prompt?")) { confirmed = true diff --git a/modules/shared.py b/modules/shared.py index 1585d532..ab5a0e9a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -317,7 +317,7 @@ options_templates.update(options_section(('ui', "User interface"), { "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), - "trash_prompt_visible": OptionInfo(True, "Show trash prompt button"), + "clear_prompt_visible": OptionInfo(True, "Show clear prompt button"), 'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"), 'localization': OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), })) diff --git a/modules/ui.py b/modules/ui.py index d3a89bf7..31150800 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -88,7 +88,7 @@ folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 -trash_prompt_symbol = '\U0001F5D1' # +clear_prompt_symbol = '\U0001F5D1' # 🗑️ def plaintext_to_html(text): @@ -430,14 +430,14 @@ def create_seed_inputs(): -def connect_trash_prompt(_prompt, confirmed, _token_counter): +def clear_prompt(_prompt, confirmed, _token_counter): if(confirmed): return ["", confirmed, update_token_counter("", 1)] -def trash_prompt_click(button, prompt, _dummy_confirmed, token_counter): +def connect_clear_prompt(button, prompt, _dummy_confirmed, token_counter): button.click( - _js="trash_prompt", - fn=connect_trash_prompt, + _js="clear_prompt", + fn=clear_prompt, inputs=[prompt, _dummy_confirmed, token_counter], outputs=[prompt, _dummy_confirmed, token_counter], ) @@ -518,7 +518,7 @@ def create_toprow(is_img2img): paste = gr.Button(value=paste_symbol, elem_id="paste") save_style = gr.Button(value=save_style_symbol, elem_id="style_create") prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - trash_prompt = gr.Button(value=trash_prompt_symbol, elem_id="trash_prompt", visible=opts.trash_prompt_visible) + clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id="clear_prompt", visible=opts.clear_prompt_visible) token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") @@ -559,7 +559,7 @@ def create_toprow(is_img2img): prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) prompt_style2.save_to_config = True - return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, trash_prompt + return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, clear_prompt_button def setup_progressbar(progressbar, preview, id_part, textinfo=None): @@ -640,7 +640,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,\ txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter,\ - token_button, trash_prompt_button = create_toprow(is_img2img=False) + token_button, clear_prompt_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) @@ -716,7 +716,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - trash_prompt_click(trash_prompt_button, txt2img_prompt, dummy_component, token_counter) + connect_clear_prompt(clear_prompt_button, txt2img_prompt, dummy_component, token_counter) txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img), @@ -853,7 +853,7 @@ def create_ui(wrap_gradio_gpu_call): with gr.Blocks(analytics_enabled=False) as img2img_interface: img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit,\ img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,\ - token_counter, token_button, trash_prompt_button = create_toprow(is_img2img=True) + token_counter, token_button, clear_prompt_button = create_toprow(is_img2img=True) with gr.Row(elem_id='img2img_progress_row'): @@ -954,7 +954,7 @@ def create_ui(wrap_gradio_gpu_call): connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - trash_prompt_click(trash_prompt_button, img2img_prompt, dummy_component, token_counter) + connect_clear_prompt(clear_prompt_button, img2img_prompt, dummy_component, token_counter) img2img_prompt_img.change( fn=modules.images.image_data, -- cgit v1.2.3 From 29bfacd63cb5c73b9643d94f255cca818fd49d9c Mon Sep 17 00:00:00 2001 From: Extraltodeus Date: Sat, 22 Oct 2022 00:12:46 +0200 Subject: implement CUDA device selection, --device-id arg --- modules/shared.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 41d7f08e..03032a47 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -80,6 +80,7 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") +parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) cmd_opts = parser.parse_args() restricted_opts = [ -- cgit v1.2.3 From 2b91251637078e04472c91a06a8d9c4db9c1dcf0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 12:23:45 +0300 Subject: removed aesthetic gradients as built-in added support for extensions --- .gitignore | 2 +- extensions/put extension here.txt | 0 modules/aesthetic_clip.py | 241 -------------------------------------- modules/images_history.py | 2 +- modules/img2img.py | 5 +- modules/processing.py | 35 ++++-- modules/script_callbacks.py | 42 +++++++ modules/scripts.py | 210 ++++++++++++++++++++++++--------- modules/sd_hijack.py | 1 - modules/sd_models.py | 7 +- modules/shared.py | 19 --- modules/txt2img.py | 5 +- modules/ui.py | 83 ++----------- webui.py | 7 +- 14 files changed, 249 insertions(+), 410 deletions(-) create mode 100644 extensions/put extension here.txt delete mode 100644 modules/aesthetic_clip.py create mode 100644 modules/script_callbacks.py (limited to 'modules/shared.py') diff --git a/.gitignore b/.gitignore index f9c3357c..2f1e08ed 100644 --- a/.gitignore +++ b/.gitignore @@ -27,4 +27,4 @@ __pycache__ notification.mp3 /SwinIR /textual_inversion -.vscode \ No newline at end of file +.vscode diff --git a/extensions/put extension here.txt b/extensions/put extension here.txt new file mode 100644 index 00000000..e69de29b diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py deleted file mode 100644 index 8c828541..00000000 --- a/modules/aesthetic_clip.py +++ /dev/null @@ -1,241 +0,0 @@ -import copy -import itertools -import os -from pathlib import Path -import html -import gc - -import gradio as gr -import torch -from PIL import Image -from torch import optim - -from modules import shared -from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer -from tqdm.auto import tqdm, trange -from modules.shared import opts, device - - -def get_all_images_in_folder(folder): - return [os.path.join(folder, f) for f in os.listdir(folder) if - os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)] - - -def check_is_valid_image_file(filename): - return filename.lower().endswith(('.png', '.jpg', '.jpeg', ".gif", ".tiff", ".webp")) - - -def batched(dataset, total, n=1): - for ndx in range(0, total, n): - yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))] - - -def iter_to_batched(iterable, n=1): - it = iter(iterable) - while True: - chunk = tuple(itertools.islice(it, n)) - if not chunk: - return - yield chunk - - -def create_ui(): - import modules.ui - - with gr.Group(): - with gr.Accordion("Open for Clip Aesthetic!", open=False): - with gr.Row(): - aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", - value=0.9) - aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) - - with gr.Row(): - aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', - placeholder="Aesthetic learning rate", value="0.0001") - aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) - aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()), - label="Aesthetic imgs embedding", - value="None") - - modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings") - - with gr.Row(): - aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', - placeholder="This text is used to rotate the feature space of the imgs embs", - value="") - aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01, - value=0.1) - aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) - - return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative - - -aesthetic_clip_model = None - - -def aesthetic_clip(): - global aesthetic_clip_model - - if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path: - aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path) - aesthetic_clip_model.cpu() - - return aesthetic_clip_model - - -def generate_imgs_embd(name, folder, batch_size): - model = aesthetic_clip().to(device) - processor = CLIPProcessor.from_pretrained(model.name_or_path) - - with torch.no_grad(): - embs = [] - for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size), - desc=f"Generating embeddings for {name}"): - if shared.state.interrupted: - break - inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device) - outputs = model.get_image_features(**inputs).cpu() - embs.append(torch.clone(outputs)) - inputs.to("cpu") - del inputs, outputs - - embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True) - - # The generated embedding will be located here - path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt") - torch.save(embs, path) - - model.cpu() - del processor - del embs - gc.collect() - torch.cuda.empty_cache() - res = f""" - Done generating embedding for {name}! - Aesthetic embedding saved to {html.escape(path)} - """ - shared.update_aesthetic_embeddings() - return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding", - value="None"), \ - gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), - label="Imgs embedding", - value="None"), res, "" - - -def slerp(low, high, val): - low_norm = low / torch.norm(low, dim=1, keepdim=True) - high_norm = high / torch.norm(high, dim=1, keepdim=True) - omega = torch.acos((low_norm * high_norm).sum(1)) - so = torch.sin(omega) - res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high - return res - - -class AestheticCLIP: - def __init__(self): - self.skip = False - self.aesthetic_steps = 0 - self.aesthetic_weight = 0 - self.aesthetic_lr = 0 - self.slerp = False - self.aesthetic_text_negative = "" - self.aesthetic_slerp_angle = 0 - self.aesthetic_imgs_text = "" - - self.image_embs_name = None - self.image_embs = None - self.load_image_embs(None) - - def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, - aesthetic_slerp=True, aesthetic_imgs_text="", - aesthetic_slerp_angle=0.15, - aesthetic_text_negative=False): - self.aesthetic_imgs_text = aesthetic_imgs_text - self.aesthetic_slerp_angle = aesthetic_slerp_angle - self.aesthetic_text_negative = aesthetic_text_negative - self.slerp = aesthetic_slerp - self.aesthetic_lr = aesthetic_lr - self.aesthetic_weight = aesthetic_weight - self.aesthetic_steps = aesthetic_steps - self.load_image_embs(image_embs_name) - - if self.image_embs_name is not None: - p.extra_generation_params.update({ - "Aesthetic LR": aesthetic_lr, - "Aesthetic weight": aesthetic_weight, - "Aesthetic steps": aesthetic_steps, - "Aesthetic embedding": self.image_embs_name, - "Aesthetic slerp": aesthetic_slerp, - "Aesthetic text": aesthetic_imgs_text, - "Aesthetic text negative": aesthetic_text_negative, - "Aesthetic slerp angle": aesthetic_slerp_angle, - }) - - def set_skip(self, skip): - self.skip = skip - - def load_image_embs(self, image_embs_name): - if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None": - image_embs_name = None - self.image_embs_name = None - if image_embs_name is not None and self.image_embs_name != image_embs_name: - self.image_embs_name = image_embs_name - self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device) - self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True) - self.image_embs.requires_grad_(False) - - def __call__(self, z, remade_batch_tokens): - if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None: - tokenizer = shared.sd_model.cond_stage_model.tokenizer - if not opts.use_old_emphasis_implementation: - remade_batch_tokens = [ - [tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in - remade_batch_tokens] - - tokens = torch.asarray(remade_batch_tokens).to(device) - - model = copy.deepcopy(aesthetic_clip()).to(device) - model.requires_grad_(True) - if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: - text_embs_2 = model.get_text_features( - **tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device)) - if self.aesthetic_text_negative: - text_embs_2 = self.image_embs - text_embs_2 - text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True) - img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle) - else: - img_embs = self.image_embs - - with torch.enable_grad(): - - # We optimize the model to maximize the similarity - optimizer = optim.Adam( - model.text_model.parameters(), lr=self.aesthetic_lr - ) - - for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"): - text_embs = model.get_text_features(input_ids=tokens) - text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True) - sim = text_embs @ img_embs.T - loss = -sim - optimizer.zero_grad() - loss.mean().backward() - optimizer.step() - - zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) - if opts.CLIP_stop_at_last_layers > 1: - zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers] - zn = model.text_model.final_layer_norm(zn) - else: - zn = zn.last_hidden_state - model.cpu() - del model - gc.collect() - torch.cuda.empty_cache() - zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1) - if self.slerp: - z = slerp(z, zn, self.aesthetic_weight) - else: - z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight - - return z diff --git a/modules/images_history.py b/modules/images_history.py index 78fd0543..bc5cf11f 100644 --- a/modules/images_history.py +++ b/modules/images_history.py @@ -310,7 +310,7 @@ def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): forward = gr.Button('Prev batch') backward = gr.Button('Next batch') with gr.Column(scale=3): - load_info = gr.HTML(visible=not custom_dir) + load_info = gr.HTML(visible=not custom_dir) with gr.Row(visible=False) as warning: warning_box = gr.Textbox("Message", interactive=False) diff --git a/modules/img2img.py b/modules/img2img.py index eea5199b..8d9f7cf9 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -56,7 +56,7 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args): +def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): is_inpaint = mode == 1 is_batch = mode == 2 @@ -109,7 +109,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro inpainting_mask_invert=inpainting_mask_invert, ) - shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative) + p.scripts = modules.scripts.scripts_txt2img + p.script_args = args if shared.cmd_opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/processing.py b/modules/processing.py index ff1ec4c9..372489f7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -104,6 +104,12 @@ class StableDiffusionProcessing(): self.seed_resize_from_h = 0 self.seed_resize_from_w = 0 + self.scripts = None + self.script_args = None + self.all_prompts = None + self.all_seeds = None + self.all_subseeds = None + def init(self, all_prompts, all_seeds, all_subseeds): pass @@ -350,32 +356,35 @@ def process_images(p: StableDiffusionProcessing) -> Processed: shared.prompt_styles.apply_styles(p) if type(p.prompt) == list: - all_prompts = p.prompt + p.all_prompts = p.prompt else: - all_prompts = p.batch_size * p.n_iter * [p.prompt] + p.all_prompts = p.batch_size * p.n_iter * [p.prompt] if type(seed) == list: - all_seeds = seed + p.all_seeds = seed else: - all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))] + p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))] if type(subseed) == list: - all_subseeds = subseed + p.all_subseeds = subseed else: - all_subseeds = [int(subseed) + x for x in range(len(all_prompts))] + p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))] def infotext(iteration=0, position_in_batch=0): - return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch) + return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch) if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: model_hijack.embedding_db.load_textual_inversion_embeddings() + if p.scripts is not None: + p.scripts.run_alwayson_scripts(p) + infotexts = [] output_images = [] with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): - p.init(all_prompts, all_seeds, all_subseeds) + p.init(p.all_prompts, p.all_seeds, p.all_subseeds) if state.job_count == -1: state.job_count = p.n_iter @@ -387,9 +396,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if state.interrupted: break - prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size] - seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size] - subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] + prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size] + seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] + subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] if (len(prompts) == 0): break @@ -490,10 +499,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: index_of_first_image = 1 if opts.grid_save: - images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) + images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) devices.torch_gc() - return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) + return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py new file mode 100644 index 00000000..866b7acd --- /dev/null +++ b/modules/script_callbacks.py @@ -0,0 +1,42 @@ + +callbacks_model_loaded = [] +callbacks_ui_tabs = [] + + +def clear_callbacks(): + callbacks_model_loaded.clear() + callbacks_ui_tabs.clear() + + +def model_loaded_callback(sd_model): + for callback in callbacks_model_loaded: + callback(sd_model) + + +def ui_tabs_callback(): + res = [] + + for callback in callbacks_ui_tabs: + res += callback() or [] + + return res + + +def on_model_loaded(callback): + """register a function to be called when the stable diffusion model is created; the model is + passed as an argument""" + callbacks_model_loaded.append(callback) + + +def on_ui_tabs(callback): + """register a function to be called when the UI is creating new tabs. + The function must either return a None, which means no new tabs to be added, or a list, where + each element is a tuple: + (gradio_component, title, elem_id) + + gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks) + title is tab text displayed to user in the UI + elem_id is HTML id for the tab + """ + callbacks_ui_tabs.append(callback) + diff --git a/modules/scripts.py b/modules/scripts.py index 1039fa9c..65f25f49 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,86 +1,153 @@ import os import sys import traceback +from collections import namedtuple import modules.ui as ui import gradio as gr from modules.processing import StableDiffusionProcessing -from modules import shared +from modules import shared, paths, script_callbacks + +AlwaysVisible = object() + class Script: filename = None args_from = None args_to = None + alwayson = False + + infotext_fields = None + """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when + parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example + """ - # The title of the script. This is what will be displayed in the dropdown menu. def title(self): + """this function should return the title of the script. This is what will be displayed in the dropdown menu.""" + raise NotImplementedError() - # How the script is displayed in the UI. See https://gradio.app/docs/#components - # for the different UI components you can use and how to create them. - # Most UI components can return a value, such as a boolean for a checkbox. - # The returned values are passed to the run method as parameters. def ui(self, is_img2img): + """this function should create gradio UI elements. See https://gradio.app/docs/#components + The return value should be an array of all components that are used in processing. + Values of those returned componenbts will be passed to run() and process() functions. + """ + pass - # Determines when the script should be shown in the dropdown menu via the - # returned value. As an example: - # is_img2img is True if the current tab is img2img, and False if it is txt2img. - # Thus, return is_img2img to only show the script on the img2img tab. def show(self, is_img2img): + """ + is_img2img is True if this function is called for the img2img interface, and Fasle otherwise + + This function should return: + - False if the script should not be shown in UI at all + - True if the script should be shown in UI if it's scelected in the scripts drowpdown + - script.AlwaysVisible if the script should be shown in UI at all times + """ + return True - # This is where the additional processing is implemented. The parameters include - # self, the model object "p" (a StableDiffusionProcessing class, see - # processing.py), and the parameters returned by the ui method. - # Custom functions can be defined here, and additional libraries can be imported - # to be used in processing. The return value should be a Processed object, which is - # what is returned by the process_images method. - def run(self, *args): + def run(self, p, *args): + """ + This function is called if the script has been selected in the script dropdown. + It must do all processing and return the Processed object with results, same as + one returned by processing.process_images. + + Usually the processing is done by calling the processing.process_images function. + + args contains all values returned by components from ui() + """ + raise NotImplementedError() - # The description method is currently unused. - # To add a description that appears when hovering over the title, amend the "titles" - # dict in script.js to include the script title (returned by title) as a key, and - # your description as the value. + def process(self, p, *args): + """ + This function is called before processing begins for AlwaysVisible scripts. + scripts. You can modify the processing object (p) here, inject hooks, etc. + """ + + pass + def describe(self): + """unused""" return "" +current_basedir = paths.script_path + + +def basedir(): + """returns the base directory for the current script. For scripts in the main scripts directory, + this is the main directory (where webui.py resides), and for scripts in extensions directory + (ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic) + """ + return current_basedir + + scripts_data = [] +ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"]) +ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"]) + + +def list_scripts(scriptdirname, extension): + scripts_list = [] + + basedir = os.path.join(paths.script_path, scriptdirname) + if os.path.exists(basedir): + for filename in sorted(os.listdir(basedir)): + scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) + + extdir = os.path.join(paths.script_path, "extensions") + if os.path.exists(extdir): + for dirname in sorted(os.listdir(extdir)): + dirpath = os.path.join(extdir, dirname) + if not os.path.isdir(dirpath): + continue + for filename in sorted(os.listdir(os.path.join(dirpath, scriptdirname))): + scripts_list.append(ScriptFile(dirpath, filename, os.path.join(dirpath, scriptdirname, filename))) -def load_scripts(basedir): - if not os.path.exists(basedir): - return + scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] - for filename in sorted(os.listdir(basedir)): - path = os.path.join(basedir, filename) + return scripts_list - if os.path.splitext(path)[1].lower() != '.py': - continue - if not os.path.isfile(path): - continue +def load_scripts(): + global current_basedir + scripts_data.clear() + script_callbacks.clear_callbacks() + + scripts_list = list_scripts("scripts", ".py") + + syspath = sys.path + for scriptfile in sorted(scripts_list): try: - with open(path, "r", encoding="utf8") as file: + if scriptfile.basedir != paths.script_path: + sys.path = [scriptfile.basedir] + sys.path + current_basedir = scriptfile.basedir + + with open(scriptfile.path, "r", encoding="utf8") as file: text = file.read() from types import ModuleType - compiled = compile(text, path, 'exec') - module = ModuleType(filename) + compiled = compile(text, scriptfile.path, 'exec') + module = ModuleType(scriptfile.filename) exec(compiled, module.__dict__) for key, script_class in module.__dict__.items(): if type(script_class) == type and issubclass(script_class, Script): - scripts_data.append((script_class, path)) + scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir)) except Exception: - print(f"Error loading script: {filename}", file=sys.stderr) + print(f"Error loading script: {scriptfile.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + finally: + sys.path = syspath + current_basedir = paths.script_path + def wrap_call(func, filename, funcname, *args, default=None, **kwargs): try: @@ -96,56 +163,80 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs): class ScriptRunner: def __init__(self): self.scripts = [] + self.selectable_scripts = [] + self.alwayson_scripts = [] self.titles = [] + self.infotext_fields = [] def setup_ui(self, is_img2img): - for script_class, path in scripts_data: + for script_class, path, basedir in scripts_data: script = script_class() script.filename = path - if not script.show(is_img2img): - continue + visibility = script.show(is_img2img) - self.scripts.append(script) + if visibility == AlwaysVisible: + self.scripts.append(script) + self.alwayson_scripts.append(script) + script.alwayson = True - self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts] + elif visibility: + self.scripts.append(script) + self.selectable_scripts.append(script) - dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index") - dropdown.save_to_config = True - inputs = [dropdown] + self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts] + + inputs = [None] + inputs_alwayson = [True] - for script in self.scripts: + def create_script_ui(script, inputs, inputs_alwayson): script.args_from = len(inputs) script.args_to = len(inputs) controls = wrap_call(script.ui, script.filename, "ui", is_img2img) if controls is None: - continue + return for control in controls: control.custom_script_source = os.path.basename(script.filename) - control.visible = False + if not script.alwayson: + control.visible = False + + if script.infotext_fields is not None: + self.infotext_fields += script.infotext_fields inputs += controls + inputs_alwayson += [script.alwayson for _ in controls] script.args_to = len(inputs) + for script in self.alwayson_scripts: + with gr.Group(): + create_script_ui(script, inputs, inputs_alwayson) + + dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index") + dropdown.save_to_config = True + inputs[0] = dropdown + + for script in self.selectable_scripts: + create_script_ui(script, inputs, inputs_alwayson) + def select_script(script_index): - if 0 < script_index <= len(self.scripts): - script = self.scripts[script_index-1] + if 0 < script_index <= len(self.selectable_scripts): + script = self.selectable_scripts[script_index-1] args_from = script.args_from args_to = script.args_to else: args_from = 0 args_to = 0 - return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))] + return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)] def init_field(title): if title == 'None': return script_index = self.titles.index(title) - script = self.scripts[script_index] + script = self.selectable_scripts[script_index] for i in range(script.args_from, script.args_to): inputs[i].visible = True @@ -164,7 +255,7 @@ class ScriptRunner: if script_index == 0: return None - script = self.scripts[script_index-1] + script = self.selectable_scripts[script_index-1] if script is None: return None @@ -176,6 +267,15 @@ class ScriptRunner: return processed + def run_alwayson_scripts(self, p): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.process(p, *script_args) + except Exception: + print(f"Error running alwayson script: {script.filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + def reload_sources(self): for si, script in list(enumerate(self.scripts)): with open(script.filename, "r", encoding="utf8") as file: @@ -197,19 +297,21 @@ class ScriptRunner: self.scripts[si].args_from = args_from self.scripts[si].args_to = args_to + scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() + def reload_script_body_only(): scripts_txt2img.reload_sources() scripts_img2img.reload_sources() -def reload_scripts(basedir): +def reload_scripts(): global scripts_txt2img, scripts_img2img - scripts_data.clear() - load_scripts(basedir) + load_scripts() scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() + diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 1f8587d1..0f10828e 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -332,7 +332,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): multipliers.append([1.0] * 75) z1 = self.process_tokens(tokens, multipliers) - z1 = shared.aesthetic_clip(z1, remade_batch_tokens) z = z1 if z is None else torch.cat((z, z1), axis=-2) remade_batch_tokens = rem_tokens diff --git a/modules/sd_models.py b/modules/sd_models.py index d99dbce8..f9b3063d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -7,7 +7,7 @@ from omegaconf import OmegaConf from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices +from modules import shared, modelloader, devices, script_callbacks from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting @@ -238,6 +238,9 @@ def load_model(checkpoint_info=None): sd_hijack.model_hijack.hijack(sd_model) sd_model.eval() + shared.sd_model = sd_model + + script_callbacks.model_loaded_callback(sd_model) print(f"Model loaded.") return sd_model @@ -252,7 +255,7 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - shared.sd_model = load_model(checkpoint_info) + load_model(checkpoint_info) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: diff --git a/modules/shared.py b/modules/shared.py index 0dbe360d..7d786f07 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -31,7 +31,6 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") -parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(models_path, 'aesthetic_embeddings'), help="aesthetic_embeddings directory(default: aesthetic_embeddings)") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") @@ -109,21 +108,6 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) loaded_hypernetwork = None - -os.makedirs(cmd_opts.aesthetic_embeddings_dir, exist_ok=True) -aesthetic_embeddings = {} - - -def update_aesthetic_embeddings(): - global aesthetic_embeddings - aesthetic_embeddings = {f.replace(".pt", ""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in - os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} - aesthetic_embeddings = OrderedDict(**{"None": None}, **aesthetic_embeddings) - - -update_aesthetic_embeddings() - - def reload_hypernetworks(): global hypernetworks @@ -415,9 +399,6 @@ sd_model = None clip_model = None -from modules.aesthetic_clip import AestheticCLIP -aesthetic_clip = AestheticCLIP() - progress_print_out = sys.stdout diff --git a/modules/txt2img.py b/modules/txt2img.py index 1761cfa2..c9d5a090 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -7,7 +7,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -36,7 +36,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: firstphase_height=firstphase_height if enable_hr else None, ) - shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative) + p.scripts = modules.scripts.scripts_txt2img + p.script_args = args if cmd_opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/ui.py b/modules/ui.py index 70a9cf10..c977482c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -23,10 +23,10 @@ import gradio as gr import gradio.utils import gradio.routes -from modules import sd_hijack, sd_models, localization +from modules import sd_hijack, sd_models, localization, script_callbacks from modules.paths import script_path -from modules.shared import opts, cmd_opts, restricted_opts, aesthetic_embeddings +from modules.shared import opts, cmd_opts, restricted_opts if cmd_opts.deepdanbooru: from modules.deepbooru import get_deepbooru_tags @@ -44,7 +44,6 @@ from modules.images import save_image import modules.textual_inversion.ui import modules.hypernetworks.ui -import modules.aesthetic_clip as aesthetic_clip import modules.images_history as img_his @@ -662,8 +661,6 @@ def create_ui(wrap_gradio_gpu_call): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() - aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative = aesthetic_clip.create_ui() - with gr.Group(): custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False) @@ -718,14 +715,6 @@ def create_ui(wrap_gradio_gpu_call): denoising_strength, firstphase_width, firstphase_height, - aesthetic_lr, - aesthetic_weight, - aesthetic_steps, - aesthetic_imgs, - aesthetic_slerp, - aesthetic_imgs_text, - aesthetic_slerp_angle, - aesthetic_text_negative ] + custom_inputs, outputs=[ @@ -804,14 +793,7 @@ def create_ui(wrap_gradio_gpu_call): (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (firstphase_width, "First pass size-1"), (firstphase_height, "First pass size-2"), - (aesthetic_lr, "Aesthetic LR"), - (aesthetic_weight, "Aesthetic weight"), - (aesthetic_steps, "Aesthetic steps"), - (aesthetic_imgs, "Aesthetic embedding"), - (aesthetic_slerp, "Aesthetic slerp"), - (aesthetic_imgs_text, "Aesthetic text"), - (aesthetic_text_negative, "Aesthetic text negative"), - (aesthetic_slerp_angle, "Aesthetic slerp angle"), + *modules.scripts.scripts_txt2img.infotext_fields ] txt2img_preview_params = [ @@ -896,8 +878,6 @@ def create_ui(wrap_gradio_gpu_call): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() - aesthetic_weight_im, aesthetic_steps_im, aesthetic_lr_im, aesthetic_slerp_im, aesthetic_imgs_im, aesthetic_imgs_text_im, aesthetic_slerp_angle_im, aesthetic_text_negative_im = aesthetic_clip.create_ui() - with gr.Group(): custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True) @@ -988,14 +968,6 @@ def create_ui(wrap_gradio_gpu_call): inpainting_mask_invert, img2img_batch_input_dir, img2img_batch_output_dir, - aesthetic_lr_im, - aesthetic_weight_im, - aesthetic_steps_im, - aesthetic_imgs_im, - aesthetic_slerp_im, - aesthetic_imgs_text_im, - aesthetic_slerp_angle_im, - aesthetic_text_negative_im, ] + custom_inputs, outputs=[ img2img_gallery, @@ -1087,14 +1059,7 @@ def create_ui(wrap_gradio_gpu_call): (seed_resize_from_w, "Seed resize from-1"), (seed_resize_from_h, "Seed resize from-2"), (denoising_strength, "Denoising strength"), - (aesthetic_lr_im, "Aesthetic LR"), - (aesthetic_weight_im, "Aesthetic weight"), - (aesthetic_steps_im, "Aesthetic steps"), - (aesthetic_imgs_im, "Aesthetic embedding"), - (aesthetic_slerp_im, "Aesthetic slerp"), - (aesthetic_imgs_text_im, "Aesthetic text"), - (aesthetic_text_negative_im, "Aesthetic text negative"), - (aesthetic_slerp_angle_im, "Aesthetic slerp angle"), + *modules.scripts.scripts_img2img.infotext_fields ] token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) @@ -1217,9 +1182,9 @@ def create_ui(wrap_gradio_gpu_call): ) #images history images_history_switch_dict = { - "fn":modules.generation_parameters_copypaste.connect_paste, - "t2i":txt2img_paste_fields, - "i2i":img2img_paste_fields + "fn": modules.generation_parameters_copypaste.connect_paste, + "t2i": txt2img_paste_fields, + "i2i": img2img_paste_fields } images_history = img_his.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) @@ -1264,18 +1229,6 @@ def create_ui(wrap_gradio_gpu_call): with gr.Column(): create_embedding = gr.Button(value="Create embedding", variant='primary') - with gr.Tab(label="Create aesthetic images embedding"): - - new_embedding_name_ae = gr.Textbox(label="Name") - process_src_ae = gr.Textbox(label='Source directory') - batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256) - with gr.Row(): - with gr.Column(scale=3): - gr.HTML(value="") - - with gr.Column(): - create_embedding_ae = gr.Button(value="Create images embedding", variant='primary') - with gr.Tab(label="Create hypernetwork"): new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) @@ -1375,21 +1328,6 @@ def create_ui(wrap_gradio_gpu_call): ] ) - create_embedding_ae.click( - fn=aesthetic_clip.generate_imgs_embd, - inputs=[ - new_embedding_name_ae, - process_src_ae, - batch_ae - ], - outputs=[ - aesthetic_imgs, - aesthetic_imgs_im, - ti_output, - ti_outcome, - ] - ) - create_hypernetwork.click( fn=modules.hypernetworks.ui.create_hypernetwork, inputs=[ @@ -1580,10 +1518,10 @@ Requested path was: {f} if not opts.same_type(value, opts.data_labels[key].default): return gr.update(visible=True), opts.dumpjson() + oldval = opts.data.get(key, None) if cmd_opts.hide_ui_dir_config and key in restricted_opts: return gr.update(value=oldval), opts.dumpjson() - oldval = opts.data.get(key, None) opts.data[key] = value if oldval != value: @@ -1692,9 +1630,12 @@ Requested path was: {f} (images_history, "Image Browser", "images_history"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "ti"), - (settings_interface, "Settings", "settings"), ] + interfaces += script_callbacks.ui_tabs_callback() + + interfaces += [(settings_interface, "Settings", "settings")] + with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file: css = file.read() diff --git a/webui.py b/webui.py index 87589064..b1deca1b 100644 --- a/webui.py +++ b/webui.py @@ -71,6 +71,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) + def initialize(): modelloader.cleanup_models() modules.sd_models.setup_model() @@ -79,9 +80,9 @@ def initialize(): shared.face_restorers.append(modules.face_restoration.FaceRestoration()) modelloader.load_upscalers() - modules.scripts.load_scripts(os.path.join(script_path, "scripts")) + modules.scripts.load_scripts() - shared.sd_model = modules.sd_models.load_model() + modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) @@ -145,7 +146,7 @@ def webui(): sd_samplers.set_samplers() print('Reloading Custom Scripts') - modules.scripts.reload_scripts(os.path.join(script_path, "scripts")) + modules.scripts.reload_scripts() print('Reloading modules: modules.ui') importlib.reload(modules.ui) print('Refreshing Model List') -- cgit v1.2.3 From d37cfffd537cd29309afbcb192c4f979995c6a34 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 19:18:56 +0300 Subject: added callback for creating new settings in extensions --- modules/script_callbacks.py | 11 +++++++++++ modules/shared.py | 19 +++++++++++++++++-- modules/ui.py | 6 +++++- 3 files changed, 33 insertions(+), 3 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 866b7acd..1270e50f 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,6 +1,7 @@ callbacks_model_loaded = [] callbacks_ui_tabs = [] +callbacks_ui_settings = [] def clear_callbacks(): @@ -22,6 +23,11 @@ def ui_tabs_callback(): return res +def ui_settings_callback(): + for callback in callbacks_ui_settings: + callback() + + def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is passed as an argument""" @@ -40,3 +46,8 @@ def on_ui_tabs(callback): """ callbacks_ui_tabs.append(callback) + +def on_ui_settings(callback): + """register a function to be called before UI settingsare populated; add your settings + by using shared.opts.add_option(shared.OptionInfo(...)) """ + callbacks_ui_settings.append(callback) diff --git a/modules/shared.py b/modules/shared.py index 5d83971e..d9cb65ef 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -165,13 +165,13 @@ def realesrgan_models_names(): class OptionInfo: - def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False, refresh=None): + def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None): self.default = default self.label = label self.component = component self.component_args = component_args self.onchange = onchange - self.section = None + self.section = section self.refresh = refresh @@ -327,6 +327,7 @@ options_templates.update(options_section(('images-history', "Images Browser"), { })) + class Options: data = None data_labels = options_templates @@ -389,6 +390,20 @@ class Options: d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()} return json.dumps(d) + def add_option(self, key, info): + self.data_labels[key] = info + + def reorder(self): + """reorder settings so that all items related to section always go together""" + + section_ids = {} + settings_items = self.data_labels.items() + for k, item in settings_items: + if item.section not in section_ids: + section_ids[item.section] = len(section_ids) + + self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])} + opts = Options() if os.path.exists(config_filename): diff --git a/modules/ui.py b/modules/ui.py index d8d52db1..2849b111 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1461,6 +1461,9 @@ def create_ui(wrap_gradio_gpu_call): components = [] component_dict = {} + script_callbacks.ui_settings_callback() + opts.reorder() + def open_folder(f): if not os.path.exists(f): print(f'Folder "{f}" does not exist. After you create an image, the folder will be created.') @@ -1564,7 +1567,8 @@ Requested path was: {f} previous_section = item.section - gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='

{}

'.format(item.section[1])) + elem_id, text = item.section + gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='

{}

'.format(text)) if k in quicksettings_names: quicksettings_list.append((i, k, item)) -- cgit v1.2.3 From 4fdb53c1e9962507fc8336dad9a0fabfe6c418c0 Mon Sep 17 00:00:00 2001 From: Unnoen Date: Wed, 19 Oct 2022 21:38:10 +1100 Subject: Generate grid preview for progress image --- modules/sd_samplers.py | 26 +++++++++++++++++++++++++- modules/shared.py | 1 + modules/ui.py | 5 ++++- 3 files changed, 30 insertions(+), 2 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index f58a29b9..74a480e5 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -7,7 +7,7 @@ import inspect import k_diffusion.sampling import ldm.models.diffusion.ddim import ldm.models.diffusion.plms -from modules import prompt_parser, devices, processing +from modules import prompt_parser, devices, processing, images from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -89,6 +89,30 @@ def sample_to_image(samples): x_sample = x_sample.astype(np.uint8) return Image.fromarray(x_sample) +def samples_to_image_grid(samples): + progress_images = [] + for i in range(len(samples)): + # Decode the samples individually to reduce VRAM usage at the cost of a bit of speed. + x_sample = processing.decode_first_stage(shared.sd_model, samples[i:i+1])[0] + x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) + x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) + x_sample = x_sample.astype(np.uint8) + progress_images.append(Image.fromarray(x_sample)) + + return images.image_grid(progress_images) + +def samples_to_image_grid_combined(samples): + progress_images = [] + # Decode all samples at once to increase speed at the cost of VRAM usage. + x_samples = processing.decode_first_stage(shared.sd_model, samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + for x_sample in x_samples: + x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) + x_sample = x_sample.astype(np.uint8) + progress_images.append(Image.fromarray(x_sample)) + + return images.image_grid(progress_images) def store_latent(decoded): state.current_latent = decoded diff --git a/modules/shared.py b/modules/shared.py index d9cb65ef..95d6e225 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -294,6 +294,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), + "progress_decode_combined": OptionInfo(False, "Decode all progress images at once. (Slighty speeds up progress generation but consumes significantly more VRAM with large batches.)"), "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), diff --git a/modules/ui.py b/modules/ui.py index 56c233ab..de0abc7e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -318,7 +318,10 @@ def check_progress_call(id_part): if shared.parallel_processing_allowed: if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None: - shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent) + if opts.progress_decode_combined: + shared.state.current_image = modules.sd_samplers.samples_to_image_grid_combined(shared.state.current_latent) + else: + shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent) shared.state.current_image_sampling_step = shared.state.sampling_step image = shared.state.current_image -- cgit v1.2.3 From d213d6ca6f90094cb45c11e2f3cb37d25a8d1f94 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 20:48:13 +0300 Subject: removed the option to use 2x more memory when generating previews added an option to always only show one image in previews removed duplicate code --- modules/sd_samplers.py | 35 ++++++++++------------------------- modules/shared.py | 2 +- modules/ui.py | 6 +++--- 3 files changed, 14 insertions(+), 29 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 74a480e5..0b408a70 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -71,6 +71,7 @@ sampler_extra_params = { 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], } + def setup_img2img_steps(p, steps=None): if opts.img2img_fix_steps or steps is not None: steps = int((steps or p.steps) / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0 @@ -82,37 +83,21 @@ def setup_img2img_steps(p, steps=None): return steps, t_enc -def sample_to_image(samples): - x_sample = processing.decode_first_stage(shared.sd_model, samples[0:1])[0] +def single_sample_to_image(sample): + x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) return Image.fromarray(x_sample) + +def sample_to_image(samples): + return single_sample_to_image(samples[0]) + + def samples_to_image_grid(samples): - progress_images = [] - for i in range(len(samples)): - # Decode the samples individually to reduce VRAM usage at the cost of a bit of speed. - x_sample = processing.decode_first_stage(shared.sd_model, samples[i:i+1])[0] - x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) - x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) - x_sample = x_sample.astype(np.uint8) - progress_images.append(Image.fromarray(x_sample)) - - return images.image_grid(progress_images) - -def samples_to_image_grid_combined(samples): - progress_images = [] - # Decode all samples at once to increase speed at the cost of VRAM usage. - x_samples = processing.decode_first_stage(shared.sd_model, samples) - x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) - - for x_sample in x_samples: - x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) - x_sample = x_sample.astype(np.uint8) - progress_images.append(Image.fromarray(x_sample)) - - return images.image_grid(progress_images) + return images.image_grid([single_sample_to_image(sample) for sample in samples]) + def store_latent(decoded): state.current_latent = decoded diff --git a/modules/shared.py b/modules/shared.py index 95d6e225..25bfc895 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -294,7 +294,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), options_templates.update(options_section(('ui', "User interface"), { "show_progressbar": OptionInfo(True, "Show progressbar"), "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}), - "progress_decode_combined": OptionInfo(False, "Decode all progress images at once. (Slighty speeds up progress generation but consumes significantly more VRAM with large batches.)"), + "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), diff --git a/modules/ui.py b/modules/ui.py index de0abc7e..ffa14cac 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -318,10 +318,10 @@ def check_progress_call(id_part): if shared.parallel_processing_allowed: if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None: - if opts.progress_decode_combined: - shared.state.current_image = modules.sd_samplers.samples_to_image_grid_combined(shared.state.current_latent) - else: + if opts.show_progress_grid: shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent) + else: + shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent) shared.state.current_image_sampling_step = shared.state.sampling_step image = shared.state.current_image -- cgit v1.2.3 From be748e8b086bd9834d08bdd9160649a5e7700af7 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 22:05:22 +0300 Subject: add --freeze-settings commandline argument to disable changing settings --- modules/shared.py | 1 + modules/ui.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 25bfc895..b55371d3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -64,6 +64,7 @@ parser.add_argument("--port", type=int, help="launch gradio with given server po parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json')) parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False) +parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False) parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json')) parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option") parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) diff --git a/modules/ui.py b/modules/ui.py index ffa14cac..2311572c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -580,6 +580,9 @@ def apply_setting(key, value): if value is None: return gr.update() + if shared.cmd_opts.freeze_settings: + return gr.update() + # dont allow model to be swapped when model hash exists in prompt if key == "sd_model_checkpoint" and opts.disable_weights_auto_swap: return gr.update() @@ -1501,6 +1504,8 @@ Requested path was: {f} def run_settings(*args): changed = 0 + assert not shared.cmd_opts.freeze_settings, "changing settings is disabled" + for key, value, comp in zip(opts.data_labels.keys(), args, components): if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default): return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson() @@ -1530,6 +1535,8 @@ Requested path was: {f} return f'{changed} settings changed.', opts.dumpjson() def run_settings_single(value, key): + assert not shared.cmd_opts.freeze_settings, "changing settings is disabled" + if not opts.same_type(value, opts.data_labels[key].default): return gr.update(visible=True), opts.dumpjson() @@ -1582,7 +1589,7 @@ Requested path was: {f} elem_id, text = item.section gr.HTML(elem_id="settings_header_text_{}".format(elem_id), value='

{}

'.format(text)) - if k in quicksettings_names: + if k in quicksettings_names and not shared.cmd_opts.freeze_settings: quicksettings_list.append((i, k, item)) components.append(dummy_component) else: @@ -1615,7 +1622,7 @@ Requested path was: {f} def reload_scripts(): modules.scripts.reload_script_body_only() - reload_javascript() # need to refresh the html page + reload_javascript() # need to refresh the html page reload_script_bodies.click( fn=reload_scripts, -- cgit v1.2.3 From 124e44cf1eed1edc68954f63a2a9bc428aabbcec Mon Sep 17 00:00:00 2001 From: yfszzx Date: Mon, 24 Oct 2022 09:51:56 +0800 Subject: remove browser to extension --- .gitignore | 1 - javascript/images_history.js | 200 -------------------- javascript/inspiration.js | 48 ----- modules/images_history.py | 424 ------------------------------------------- modules/inspiration.py | 193 -------------------- modules/script_callbacks.py | 2 - modules/shared.py | 15 -- modules/ui.py | 20 +- 8 files changed, 4 insertions(+), 899 deletions(-) delete mode 100644 javascript/images_history.js delete mode 100644 javascript/inspiration.js delete mode 100644 modules/images_history.py delete mode 100644 modules/inspiration.py (limited to 'modules/shared.py') diff --git a/.gitignore b/.gitignore index 8d01bc6a..70660c51 100644 --- a/.gitignore +++ b/.gitignore @@ -29,5 +29,4 @@ notification.mp3 /textual_inversion .vscode /extensions - /inspiration diff --git a/javascript/images_history.js b/javascript/images_history.js deleted file mode 100644 index c9aa76f8..00000000 --- a/javascript/images_history.js +++ /dev/null @@ -1,200 +0,0 @@ -var images_history_click_image = function(){ - if (!this.classList.contains("transform")){ - var gallery = images_history_get_parent_by_class(this, "images_history_cantainor"); - var buttons = gallery.querySelectorAll(".gallery-item"); - var i = 0; - var hidden_list = []; - buttons.forEach(function(e){ - if (e.style.display == "none"){ - hidden_list.push(i); - } - i += 1; - }) - if (hidden_list.length > 0){ - setTimeout(images_history_hide_buttons, 10, hidden_list, gallery); - } - } - images_history_set_image_info(this); -} - -function images_history_disabled_del(){ - gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){ - btn.setAttribute('disabled','disabled'); - }); -} - -function images_history_get_parent_by_class(item, class_name){ - var parent = item.parentElement; - while(!parent.classList.contains(class_name)){ - parent = parent.parentElement; - } - return parent; -} - -function images_history_get_parent_by_tagname(item, tagname){ - var parent = item.parentElement; - tagname = tagname.toUpperCase() - while(parent.tagName != tagname){ - parent = parent.parentElement; - } - return parent; -} - -function images_history_hide_buttons(hidden_list, gallery){ - var buttons = gallery.querySelectorAll(".gallery-item"); - var num = 0; - buttons.forEach(function(e){ - if (e.style.display == "none"){ - num += 1; - } - }); - if (num == hidden_list.length){ - setTimeout(images_history_hide_buttons, 10, hidden_list, gallery); - } - for( i in hidden_list){ - buttons[hidden_list[i]].style.display = "none"; - } -} - -function images_history_set_image_info(button){ - var buttons = images_history_get_parent_by_tagname(button, "DIV").querySelectorAll(".gallery-item"); - var index = -1; - var i = 0; - buttons.forEach(function(e){ - if(e == button){ - index = i; - } - if(e.style.display != "none"){ - i += 1; - } - }); - var gallery = images_history_get_parent_by_class(button, "images_history_cantainor"); - var set_btn = gallery.querySelector(".images_history_set_index"); - var curr_idx = set_btn.getAttribute("img_index", index); - if (curr_idx != index) { - set_btn.setAttribute("img_index", index); - images_history_disabled_del(); - } - set_btn.click(); - -} - -function images_history_get_current_img(tabname, img_index, files){ - return [ - tabname, - gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"), - files - ]; -} - -function images_history_delete(del_num, tabname, image_index){ - image_index = parseInt(image_index); - var tab = gradioApp().getElementById(tabname + '_images_history'); - var set_btn = tab.querySelector(".images_history_set_index"); - var buttons = []; - tab.querySelectorAll(".gallery-item").forEach(function(e){ - if (e.style.display != 'none'){ - buttons.push(e); - } - }); - var img_num = buttons.length / 2; - del_num = Math.min(img_num - image_index, del_num) - if (img_num <= del_num){ - setTimeout(function(tabname){ - gradioApp().getElementById(tabname + '_images_history_renew_page').click(); - }, 30, tabname); - } else { - var next_img - for (var i = 0; i < del_num; i++){ - buttons[image_index + i].style.display = 'none'; - buttons[image_index + i + img_num].style.display = 'none'; - next_img = image_index + i + 1 - } - var bnt; - if (next_img >= img_num){ - btn = buttons[image_index - 1]; - } else { - btn = buttons[next_img]; - } - setTimeout(function(btn){btn.click()}, 30, btn); - } - images_history_disabled_del(); - -} - -function images_history_turnpage(tabname){ - gradioApp().getElementById(tabname + '_images_history_del_button').setAttribute('disabled','disabled'); - var buttons = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item"); - buttons.forEach(function(elem) { - elem.style.display = 'block'; - }) -} - -function images_history_enable_del_buttons(){ - gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){ - btn.removeAttribute('disabled'); - }) -} - -function images_history_init(){ - var tabnames = gradioApp().getElementById("images_history_tabnames_list") - if (tabnames){ - images_history_tab_list = tabnames.querySelector("textarea").value.split(",") - for (var i in images_history_tab_list ){ - var tab = images_history_tab_list[i]; - gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor"); - gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index"); - gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button"); - gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery"); - gradioApp().getElementById(tab + "_images_history_start").setAttribute("style","padding:20px;font-size:25px"); - } - - //preload - if (gradioApp().getElementById("images_history_preload").querySelector("input").checked ){ - var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div"); - tabs_box.setAttribute("id", "images_history_tab"); - var tab_btns = tabs_box.querySelectorAll("button"); - for (var i in images_history_tab_list){ - var tabname = images_history_tab_list[i] - tab_btns[i].setAttribute("tabname", tabname); - tab_btns[i].addEventListener('click', function(){ - var tabs_box = gradioApp().getElementById("images_history_tab"); - if (!tabs_box.classList.contains(this.getAttribute("tabname"))) { - gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_start").click(); - tabs_box.classList.add(this.getAttribute("tabname")) - } - }); - } - tab_btns[0].click() - } - } else { - setTimeout(images_history_init, 500); - } -} - -var images_history_tab_list = ""; -setTimeout(images_history_init, 500); -document.addEventListener("DOMContentLoaded", function() { - var mutationObserver = new MutationObserver(function(m){ - if (images_history_tab_list != ""){ - for (var i in images_history_tab_list ){ - let tabname = images_history_tab_list[i] - var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item'); - buttons.forEach(function(bnt){ - bnt.addEventListener('click', images_history_click_image, true); - }); - - var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg"); - if (cls_btn){ - cls_btn.addEventListener('click', function(){ - gradioApp().getElementById(tabname + '_images_history_renew_page').click(); - }, false); - } - - } - } - }); - mutationObserver.observe(gradioApp(), { childList:true, subtree:true }); -}); - - diff --git a/javascript/inspiration.js b/javascript/inspiration.js deleted file mode 100644 index 39844544..00000000 --- a/javascript/inspiration.js +++ /dev/null @@ -1,48 +0,0 @@ -function public_image_index_in_gallery(item, gallery){ - var imgs = gallery.querySelectorAll("img.h-full") - var index; - var i = 0; - imgs.forEach(function(e){ - if (e == item) - index = i; - i += 1; - }); - var all_imgs = gallery.querySelectorAll("img") - if (all_imgs.length > imgs.length){ - var num = imgs.length / 2 - index = (index < num) ? index : (index - num) - } - return index; -} - -function inspiration_selected(name, name_list){ - var btn = gradioApp().getElementById("inspiration_select_button") - return [gradioApp().getElementById("inspiration_select_button").getAttribute("img-index")]; -} - -function inspiration_click_get_button(){ - gradioApp().getElementById("inspiration_get_button").click(); -} - -var inspiration_image_click = function(){ - var index = public_image_index_in_gallery(this, gradioApp().getElementById("inspiration_gallery")); - var btn = gradioApp().getElementById("inspiration_select_button"); - btn.setAttribute("img-index", index); - setTimeout(function(btn){btn.click();}, 10, btn); -} - -document.addEventListener("DOMContentLoaded", function() { - var mutationObserver = new MutationObserver(function(m){ - var gallery = gradioApp().getElementById("inspiration_gallery") - if (gallery) { - var node = gallery.querySelector(".absolute.backdrop-blur.h-full") - if (node) { - node.style.display = "None"; - } - gallery.querySelectorAll('img').forEach(function(e){ - e.onclick = inspiration_image_click - }); - } - }); - mutationObserver.observe( gradioApp(), { childList:true, subtree:true }); -}); diff --git a/modules/images_history.py b/modules/images_history.py deleted file mode 100644 index bc5cf11f..00000000 --- a/modules/images_history.py +++ /dev/null @@ -1,424 +0,0 @@ -import os -import shutil -import time -import hashlib -import gradio -system_bak_path = "webui_log_and_bak" -custom_tab_name = "custom fold" -faverate_tab_name = "favorites" -tabs_list = ["txt2img", "img2img", "extras", faverate_tab_name] -def is_valid_date(date): - try: - time.strptime(date, "%Y%m%d") - return True - except: - return False - -def reduplicative_file_move(src, dst): - def same_name_file(basename, path): - name, ext = os.path.splitext(basename) - f_list = os.listdir(path) - max_num = 0 - for f in f_list: - if len(f) <= len(basename): - continue - f_ext = f[-len(ext):] if len(ext) > 0 else "" - if f[:len(name)] == name and f_ext == ext: - if f[len(name)] == "(" and f[-len(ext)-1] == ")": - number = f[len(name)+1:-len(ext)-1] - if number.isdigit(): - if int(number) > max_num: - max_num = int(number) - return f"{name}({max_num + 1}){ext}" - name = os.path.basename(src) - save_name = os.path.join(dst, name) - if not os.path.exists(save_name): - shutil.move(src, dst) - else: - name = same_name_file(name, dst) - shutil.move(src, os.path.join(dst, name)) - -def traverse_all_files(curr_path, image_list, all_type=False): - try: - f_list = os.listdir(curr_path) - except: - if all_type or (curr_path[-10:].rfind(".") > 0 and curr_path[-4:] != ".txt" and curr_path[-4:] != ".csv"): - image_list.append(curr_path) - return image_list - for file in f_list: - file = os.path.join(curr_path, file) - if (not all_type) and (file[-4:] == ".txt" or file[-4:] == ".csv"): - pass - elif os.path.isfile(file) and file[-10:].rfind(".") > 0: - image_list.append(file) - else: - image_list = traverse_all_files(file, image_list) - return image_list - -def auto_sorting(dir_name): - bak_path = os.path.join(dir_name, system_bak_path) - if not os.path.exists(bak_path): - os.mkdir(bak_path) - log_file = None - files_list = [] - f_list = os.listdir(dir_name) - for file in f_list: - if file == system_bak_path: - continue - file_path = os.path.join(dir_name, file) - if not is_valid_date(file): - if file[-10:].rfind(".") > 0: - files_list.append(file_path) - else: - files_list = traverse_all_files(file_path, files_list, all_type=True) - - for file in files_list: - date_str = time.strftime("%Y%m%d",time.localtime(os.path.getmtime(file))) - file_path = os.path.dirname(file) - hash_path = hashlib.md5(file_path.encode()).hexdigest() - path = os.path.join(dir_name, date_str, hash_path) - if not os.path.exists(path): - os.makedirs(path) - if log_file is None: - log_file = open(os.path.join(bak_path,"path_mapping.csv"),"a") - log_file.write(f"{hash_path},{file_path}\n") - reduplicative_file_move(file, path) - - date_list = [] - f_list = os.listdir(dir_name) - for f in f_list: - if is_valid_date(f): - date_list.append(f) - elif f == system_bak_path: - continue - else: - try: - reduplicative_file_move(os.path.join(dir_name, f), bak_path) - except: - pass - - today = time.strftime("%Y%m%d",time.localtime(time.time())) - if today not in date_list: - date_list.append(today) - return sorted(date_list, reverse=True) - -def archive_images(dir_name, date_to): - filenames = [] - batch_size =int(opts.images_history_num_per_page * opts.images_history_pages_num) - if batch_size <= 0: - batch_size = opts.images_history_num_per_page * 6 - today = time.strftime("%Y%m%d",time.localtime(time.time())) - date_to = today if date_to is None or date_to == "" else date_to - date_to_bak = date_to - if False: #opts.images_history_reconstruct_directory: - date_list = auto_sorting(dir_name) - for date in date_list: - if date <= date_to: - path = os.path.join(dir_name, date) - if date == today and not os.path.exists(path): - continue - filenames = traverse_all_files(path, filenames) - if len(filenames) > batch_size: - break - filenames = sorted(filenames, key=lambda file: -os.path.getmtime(file)) - else: - filenames = traverse_all_files(dir_name, filenames) - total_num = len(filenames) - tmparray = [(os.path.getmtime(file), file) for file in filenames ] - date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400 - filenames = [] - date_list = {date_to:None} - date = time.strftime("%Y%m%d",time.localtime(time.time())) - for t, f in tmparray: - date = time.strftime("%Y%m%d",time.localtime(t)) - date_list[date] = None - if t <= date_stamp: - filenames.append((t, f ,date)) - date_list = sorted(list(date_list.keys()), reverse=True) - sort_array = sorted(filenames, key=lambda x:-x[0]) - if len(sort_array) > batch_size: - date = sort_array[batch_size][2] - filenames = [x[1] for x in sort_array] - else: - date = date_to if len(sort_array) == 0 else sort_array[-1][2] - filenames = [x[1] for x in sort_array] - filenames = [x[1] for x in sort_array if x[2]>= date] - num = len(filenames) - last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000)) - date = date[:4] + "/" + date[4:6] + "/" + date[6:8] - date_to_bak = date_to_bak[:4] + "/" + date_to_bak[4:6] + "/" + date_to_bak[6:8] - load_info = "
" - load_info += f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages" - load_info += "
" - _, image_list, _, _, visible_num = get_recent_images(1, 0, filenames) - return ( - date_to, - load_info, - filenames, - 1, - image_list, - "", - "", - visible_num, - last_date_from, - gradio.update(visible=total_num > num) - ) - -def delete_image(delete_num, name, filenames, image_index, visible_num): - if name == "": - return filenames, delete_num - else: - delete_num = int(delete_num) - visible_num = int(visible_num) - image_index = int(image_index) - index = list(filenames).index(name) - i = 0 - new_file_list = [] - for name in filenames: - if i >= index and i < index + delete_num: - if os.path.exists(name): - if visible_num == image_index: - new_file_list.append(name) - i += 1 - continue - print(f"Delete file {name}") - os.remove(name) - visible_num -= 1 - txt_file = os.path.splitext(name)[0] + ".txt" - if os.path.exists(txt_file): - os.remove(txt_file) - else: - print(f"Not exists file {name}") - else: - new_file_list.append(name) - i += 1 - return new_file_list, 1, visible_num - -def save_image(file_name): - if file_name is not None and os.path.exists(file_name): - shutil.copy(file_name, opts.outdir_save) - -def get_recent_images(page_index, step, filenames): - page_index = int(page_index) - num_of_imgs_per_page = int(opts.images_history_num_per_page) - max_page_index = len(filenames) // num_of_imgs_per_page + 1 - page_index = max_page_index if page_index == -1 else page_index + step - page_index = 1 if page_index < 1 else page_index - page_index = max_page_index if page_index > max_page_index else page_index - idx_frm = (page_index - 1) * num_of_imgs_per_page - image_list = filenames[idx_frm:idx_frm + num_of_imgs_per_page] - length = len(filenames) - visible_num = num_of_imgs_per_page if idx_frm + num_of_imgs_per_page <= length else length % num_of_imgs_per_page - visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num - return page_index, image_list, "", "", visible_num - -def loac_batch_click(date_to): - if date_to is None: - return time.strftime("%Y%m%d",time.localtime(time.time())), [] - else: - return None, [] -def forward_click(last_date_from, date_to_recorder): - if len(date_to_recorder) == 0: - return None, [] - if last_date_from == date_to_recorder[-1]: - date_to_recorder = date_to_recorder[:-1] - if len(date_to_recorder) == 0: - return None, [] - return date_to_recorder[-1], date_to_recorder[:-1] - -def backward_click(last_date_from, date_to_recorder): - if last_date_from is None or last_date_from == "": - return time.strftime("%Y%m%d",time.localtime(time.time())), [] - if len(date_to_recorder) == 0 or last_date_from != date_to_recorder[-1]: - date_to_recorder.append(last_date_from) - return last_date_from, date_to_recorder - - -def first_page_click(page_index, filenames): - return get_recent_images(1, 0, filenames) - -def end_page_click(page_index, filenames): - return get_recent_images(-1, 0, filenames) - -def prev_page_click(page_index, filenames): - return get_recent_images(page_index, -1, filenames) - -def next_page_click(page_index, filenames): - return get_recent_images(page_index, 1, filenames) - -def page_index_change(page_index, filenames): - return get_recent_images(page_index, 0, filenames) - -def show_image_info(tabname_box, num, page_index, filenames): - file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))] - tm = "
" + time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + "
" - return file, tm, num, file - -def enable_page_buttons(): - return gradio.update(visible=True) - -def change_dir(img_dir, date_to): - warning = None - try: - if os.path.exists(img_dir): - try: - f = os.listdir(img_dir) - except: - warning = f"'{img_dir} is not a directory" - else: - warning = "The directory is not exist" - except: - warning = "The format of the directory is incorrect" - if warning is None: - today = time.strftime("%Y%m%d",time.localtime(time.time())) - return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today, gradio.update(visible=True), gradio.update(visible=True) - else: - return gradio.update(visible=True), gradio.update(visible=False), warning, date_to, gradio.update(visible=False), gradio.update(visible=False) - -def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): - custom_dir = False - if tabname == "txt2img": - dir_name = opts.outdir_txt2img_samples - elif tabname == "img2img": - dir_name = opts.outdir_img2img_samples - elif tabname == "extras": - dir_name = opts.outdir_extras_samples - elif tabname == faverate_tab_name: - dir_name = opts.outdir_save - else: - custom_dir = True - dir_name = None - - if not custom_dir: - d = dir_name.split("/") - dir_name = d[0] - for p in d[1:]: - dir_name = os.path.join(dir_name, p) - if not os.path.exists(dir_name): - os.makedirs(dir_name) - - with gr.Column() as page_panel: - with gr.Row(): - with gr.Column(scale=1, visible=not custom_dir) as load_batch_box: - load_batch = gr.Button('Load', elem_id=tabname + "_images_history_start", full_width=True) - with gr.Column(scale=4): - with gr.Row(): - img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir) - with gr.Row(): - with gr.Column(visible=False, scale=1) as batch_panel: - with gr.Row(): - forward = gr.Button('Prev batch') - backward = gr.Button('Next batch') - with gr.Column(scale=3): - load_info = gr.HTML(visible=not custom_dir) - with gr.Row(visible=False) as warning: - warning_box = gr.Textbox("Message", interactive=False) - - with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel: - with gr.Column(scale=2): - with gr.Row(visible=True) as turn_page_buttons: - #date_to = gr.Dropdown(label="Date to") - first_page = gr.Button('First Page') - prev_page = gr.Button('Prev Page') - page_index = gr.Number(value=1, label="Page Index") - next_page = gr.Button('Next Page') - end_page = gr.Button('End Page') - - history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=opts.images_history_grid_num) - with gr.Row(): - delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") - delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") - - with gr.Column(): - with gr.Row(): - with gr.Column(): - img_file_info = gr.Textbox(label="Generate Info", interactive=False, lines=6) - gr.HTML("
") - img_file_name = gr.Textbox(value="", label="File Name", interactive=False) - img_file_time= gr.HTML() - with gr.Row(): - if tabname != faverate_tab_name: - save_btn = gr.Button('Collect') - pnginfo_send_to_txt2img = gr.Button('Send to txt2img') - pnginfo_send_to_img2img = gr.Button('Send to img2img') - - - # hiden items - with gr.Row(visible=False): - renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page") - batch_date_to = gr.Textbox(label="Date to") - visible_img_num = gr.Number() - date_to_recorder = gr.State([]) - last_date_from = gr.Textbox() - tabname_box = gr.Textbox(tabname) - image_index = gr.Textbox(value=-1) - set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index") - filenames = gr.State() - all_images_list = gr.State() - hidden = gr.Image(type="pil") - info1 = gr.Textbox() - info2 = gr.Textbox() - - img_path.submit(change_dir, inputs=[img_path, batch_date_to], outputs=[warning, main_panel, warning_box, batch_date_to, load_batch_box, load_info]) - - #change batch - change_date_output = [batch_date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from, batch_panel] - - batch_date_to.change(archive_images, inputs=[img_path, batch_date_to], outputs=change_date_output) - batch_date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons]) - batch_date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - - load_batch.click(loac_batch_click, inputs=[batch_date_to], outputs=[batch_date_to, date_to_recorder]) - forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder]) - backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder]) - - - #delete - delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num]) - delete.click(fn=None, _js="images_history_delete", inputs=[delete_num, tabname_box, image_index], outputs=None) - if tabname != faverate_tab_name: - save_btn.click(save_image, inputs=[img_file_name], outputs=None) - - #turn page - gallery_inputs = [page_index, filenames] - gallery_outputs = [page_index, history_gallery, img_file_name, img_file_time, visible_img_num] - first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs) - next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs) - prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs) - end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs) - page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs) - renew_page.click(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs) - - first_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - next_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - prev_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - end_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - page_index.submit(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - renew_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage") - - # other funcitons - set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, img_file_time, image_index, hidden]) - img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) - hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) - switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') - switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') - - - -def create_history_tabs(gr, sys_opts, cmp_ops, run_pnginfo, switch_dict): - global opts; - opts = sys_opts - loads_files_num = int(opts.images_history_num_per_page) - num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num) - if cmp_ops.browse_all_images: - tabs_list.append(custom_tab_name) - with gr.Blocks(analytics_enabled=False) as images_history: - with gr.Tabs() as tabs: - for tab in tabs_list: - with gr.Tab(tab): - with gr.Blocks(analytics_enabled=False) : - show_images_history(gr, opts, tab, run_pnginfo, switch_dict) - gradio.Checkbox(opts.images_history_preload, elem_id="images_history_preload", visible=False) - gradio.Textbox(",".join(tabs_list), elem_id="images_history_tabnames_list", visible=False) - - return images_history diff --git a/modules/inspiration.py b/modules/inspiration.py deleted file mode 100644 index 29cf8297..00000000 --- a/modules/inspiration.py +++ /dev/null @@ -1,193 +0,0 @@ -import os -import random -import gradio -from modules.shared import opts -inspiration_system_path = os.path.join(opts.inspiration_dir, "system") -def read_name_list(file, types=None, keyword=None): - if not os.path.exists(file): - return [] - ret = [] - f = open(file, "r") - line = f.readline() - while len(line) > 0: - line = line.rstrip("\n") - if types is not None: - dirname = os.path.split(line) - if dirname[0] in types and keyword in dirname[1].lower(): - ret.append(line) - else: - ret.append(line) - line = f.readline() - return ret - -def save_name_list(file, name): - name_list = read_name_list(file) - if name not in name_list: - with open(file, "a") as f: - f.write(name + "\n") - -def get_types_list(): - files = os.listdir(opts.inspiration_dir) - types = [] - for x in files: - path = os.path.join(opts.inspiration_dir, x) - if x[0] == ".": - continue - if not os.path.isdir(path): - continue - if path == inspiration_system_path: - continue - types.append(x) - return types - -def get_inspiration_images(source, types, keyword): - keyword = keyword.strip(" ").lower() - get_num = int(opts.inspiration_rows_num * opts.inspiration_cols_num) - if source == "Favorites": - names = read_name_list(os.path.join(inspiration_system_path, "faverites.txt"), types, keyword) - names = random.sample(names, get_num) if len(names) > get_num else names - elif source == "Abandoned": - names = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword) - names = random.sample(names, get_num) if len(names) > get_num else names - elif source == "Exclude abandoned": - abandoned = read_name_list(os.path.join(inspiration_system_path, "abandoned.txt"), types, keyword) - all_names = [] - for tp in types: - name_list = os.listdir(os.path.join(opts.inspiration_dir, tp)) - all_names += [os.path.join(tp, x) for x in name_list if keyword in x.lower()] - - if len(all_names) > get_num: - names = [] - while len(names) < get_num: - name = random.choice(all_names) - if name not in abandoned: - names.append(name) - else: - names = all_names - else: - all_names = [] - for tp in types: - name_list = os.listdir(os.path.join(opts.inspiration_dir, tp)) - all_names += [os.path.join(tp, x) for x in name_list if keyword in x.lower()] - names = random.sample(all_names, get_num) if len(all_names) > get_num else all_names - image_list = [] - for a in names: - image_path = os.path.join(opts.inspiration_dir, a) - images = os.listdir(image_path) - if len(images) > 0: - image_list.append((os.path.join(image_path, random.choice(images)), a)) - else: - print(image_path) - return image_list, names - -def select_click(index, name_list): - name = name_list[int(index)] - path = os.path.join(opts.inspiration_dir, name) - images = os.listdir(path) - return name, [os.path.join(path, x) for x in images], "" - -def give_up_click(name): - file = os.path.join(inspiration_system_path, "abandoned.txt") - save_name_list(file, name) - return "Added to abandoned list" - -def collect_click(name): - file = os.path.join(inspiration_system_path, "faverites.txt") - save_name_list(file, name) - return "Added to faverite list" - -def moveout_click(name, source): - if source == "Abandoned": - file = os.path.join(inspiration_system_path, "abandoned.txt") - elif source == "Favorites": - file = os.path.join(inspiration_system_path, "faverites.txt") - else: - return None - name_list = read_name_list(file) - os.remove(file) - with open(file, "a") as f: - for a in name_list: - if a != name: - f.write(a + "\n") - return f"Moved out {name} from {source} list" - -def source_change(source): - if source in ["Abandoned", "Favorites"]: - return gradio.update(visible=True), [] - else: - return gradio.update(visible=False), [] -def add_to_prompt(name, prompt): - name = os.path.basename(name) - return prompt + "," + name - -def clear_keyword(): - return "" - -def ui(gr, opts, txt2img_prompt, img2img_prompt): - with gr.Blocks(analytics_enabled=False) as inspiration: - flag = os.path.exists(opts.inspiration_dir) - if flag: - types = get_types_list() - flag = len(types) > 0 - else: - os.makedirs(opts.inspiration_dir) - if not flag: - gr.HTML(""" -

To activate inspiration function, you need get "inspiration" images first.


- You can create these images by run "Create inspiration images" script in txt2img page,
you can get the artists or art styles list from here
- https://github.com/pharmapsychotic/clip-interrogator/tree/main/data
- download these files, and select these files in the "Create inspiration images" script UI
- There about 6000 artists and art styles in these files.
This takes server hours depending on your GPU type and how many pictures you generate for each artist/style -
I suggest at least four images for each


-

You can also download generated pictures from here:


- https://huggingface.co/datasets/yfszzx/inspiration
- unzip the file to the project directory of webui
- and restart webui, and enjoy the joy of creation!
- """) - return inspiration - if not os.path.exists(inspiration_system_path): - os.mkdir(inspiration_system_path) - with gr.Row(): - with gr.Column(scale=2): - inspiration_gallery = gr.Gallery(show_label=False, elem_id="inspiration_gallery").style(grid=opts.inspiration_cols_num, height='auto') - with gr.Column(scale=1): - types = gr.CheckboxGroup(choices=types, value=types) - with gr.Row(): - source = gr.Dropdown(choices=["All", "Favorites", "Exclude abandoned", "Abandoned"], value="Exclude abandoned", label="Source") - keyword = gr.Textbox("", label="Key word") - get_inspiration = gr.Button("Get inspiration", elem_id="inspiration_get_button") - name = gr.Textbox(show_label=False, interactive=False) - with gr.Row(): - send_to_txt2img = gr.Button('to txt2img') - send_to_img2img = gr.Button('to img2img') - collect = gr.Button('Collect') - give_up = gr.Button("Don't show again") - moveout = gr.Button("Move out", visible=False) - warning = gr.HTML() - style_gallery = gr.Gallery(show_label=False).style(grid=2, height='auto') - - - - with gr.Row(visible=False): - select_button = gr.Button('set button', elem_id="inspiration_select_button") - name_list = gr.State() - - get_inspiration.click(get_inspiration_images, inputs=[source, types, keyword], outputs=[inspiration_gallery, name_list]) - keyword.submit(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) - source.change(source_change, inputs=[source], outputs=[moveout, style_gallery]) - source.change(fn=clear_keyword, _js="inspiration_click_get_button", inputs=None, outputs=[keyword]) - types.change(fn=clear_keyword, _js="inspiration_click_get_button", inputs=None, outputs=[keyword]) - - select_button.click(select_click, _js="inspiration_selected", inputs=[name, name_list], outputs=[name, style_gallery, warning]) - give_up.click(give_up_click, inputs=[name], outputs=[warning]) - collect.click(collect_click, inputs=[name], outputs=[warning]) - moveout.click(moveout_click, inputs=[name, source], outputs=[warning]) - moveout.click(fn=None, _js="inspiration_click_get_button", inputs=None, outputs=None) - - send_to_txt2img.click(add_to_prompt, inputs=[name, txt2img_prompt], outputs=[txt2img_prompt]) - send_to_img2img.click(add_to_prompt, inputs=[name, img2img_prompt], outputs=[img2img_prompt]) - send_to_txt2img.click(collect_click, inputs=[name], outputs=[warning]) - send_to_img2img.click(collect_click, inputs=[name], outputs=[warning]) - send_to_txt2img.click(None, _js='switch_to_txt2img', inputs=None, outputs=None) - send_to_img2img.click(None, _js="switch_to_img2img_img2img", inputs=None, outputs=None) - return inspiration diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 5bcccd67..66666a56 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,4 +1,3 @@ - callbacks_model_loaded = [] callbacks_ui_tabs = [] callbacks_ui_settings = [] @@ -16,7 +15,6 @@ def model_loaded_callback(sd_model): def ui_tabs_callback(): res = [] - for callback in callbacks_ui_tabs: res += callback() or [] diff --git a/modules/shared.py b/modules/shared.py index 0aaaadac..5dfd7927 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -321,21 +321,6 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) -options_templates.update(options_section(('inspiration', "Inspiration"), { - "inspiration_dir": OptionInfo("inspiration", "Directory of inspiration", component_args=hide_dirs), - "inspiration_max_samples": OptionInfo(4, "Maximum number of samples, used to determine which folders to skip when continue running the create script", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}), - "inspiration_rows_num": OptionInfo(4, "Rows of inspiration interface frame", gr.Slider, {"minimum": 4, "maximum": 16, "step": 1}), - "inspiration_cols_num": OptionInfo(8, "Columns of inspiration interface frame", gr.Slider, {"minimum": 4, "maximum": 16, "step": 1}), -})) - -options_templates.update(options_section(('images-history', "Images Browser"), { - #"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"), - "images_history_preload": OptionInfo(False, "Preload images at startup"), - "images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"), - "images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "), - "images_history_grid_num": OptionInfo(6, "Number of grids in each row"), - -})) class Options: data = None diff --git a/modules/ui.py b/modules/ui.py index a73175f5..fa42712e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -49,14 +49,12 @@ from modules.sd_hijack import model_hijack from modules.sd_samplers import samplers, samplers_for_img2img import modules.textual_inversion.ui import modules.hypernetworks.ui -import modules.images_history as images_history -import modules.inspiration as inspiration - - # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() mimetypes.add_type('application/javascript', '.js') +txt2img_paste_fields = [] +img2img_paste_fields = [] if not cmd_opts.share and not cmd_opts.listen: @@ -1193,16 +1191,7 @@ def create_ui(wrap_gradio_gpu_call): inputs=[image], outputs=[html, generation_info, html2], ) - #images history - images_history_switch_dict = { - "fn": modules.generation_parameters_copypaste.connect_paste, - "t2i": txt2img_paste_fields, - "i2i": img2img_paste_fields - } - - browser_interface = images_history.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) - inspiration_interface = inspiration.ui(gr, opts, txt2img_prompt, img2img_prompt) - + with gr.Blocks() as modelmerger_interface: with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): @@ -1651,8 +1640,6 @@ Requested path was: {f} (img2img_interface, "img2img", "img2img"), (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), - (inspiration_interface, "Inspiration", "inspiration"), - (browser_interface , "Image Browser", "images_history"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "ti"), ] @@ -1896,6 +1883,7 @@ def load_javascript(raw_response): javascript = f'' scripts_list = modules.scripts.list_scripts("javascript", ".js") + scripts_list += modules.scripts.list_scripts("scripts", ".js") for basedir, filename, path in scripts_list: with open(path, "r", encoding="utf8") as jsfile: javascript += f"\n" -- cgit v1.2.3 From cef1b89aa2e6c7647db7e93a4cd4ec020da3f2da Mon Sep 17 00:00:00 2001 From: yfszzx Date: Mon, 24 Oct 2022 10:10:33 +0800 Subject: remove browser to extension --- modules/script_callbacks.py | 2 ++ modules/shared.py | 1 - modules/ui.py | 2 +- scripts/create_inspiration_images.py | 57 ------------------------------------ 4 files changed, 3 insertions(+), 59 deletions(-) delete mode 100644 scripts/create_inspiration_images.py (limited to 'modules/shared.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 66666a56..f46d3d9a 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,3 +1,4 @@ + callbacks_model_loaded = [] callbacks_ui_tabs = [] callbacks_ui_settings = [] @@ -15,6 +16,7 @@ def model_loaded_callback(sd_model): def ui_tabs_callback(): res = [] + for callback in callbacks_ui_tabs: res += callback() or [] diff --git a/modules/shared.py b/modules/shared.py index 5dfd7927..6541e679 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -82,7 +82,6 @@ parser.add_argument("--api", action='store_true', help="use api=True to launch t parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) -parser.add_argument("--browse-all-images", action='store_true', help="Allow browsing all images by Image Browser", default=False) cmd_opts = parser.parse_args() restricted_opts = [ diff --git a/modules/ui.py b/modules/ui.py index fa42712e..a32f7259 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1104,7 +1104,7 @@ def create_ui(wrap_gradio_gpu_call): upscaling_crop = gr.Checkbox(label='Crop to fit', value=True) with gr.Group(): - extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers] , value=shared.sd_upscalers[0].name, type="index") + extras_upscaler_1 = gr.Radio(label='Upscaler 1', elem_id="extras_upscaler_1", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") with gr.Group(): extras_upscaler_2 = gr.Radio(label='Upscaler 2', elem_id="extras_upscaler_2", choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") diff --git a/scripts/create_inspiration_images.py b/scripts/create_inspiration_images.py deleted file mode 100644 index 2fd30578..00000000 --- a/scripts/create_inspiration_images.py +++ /dev/null @@ -1,57 +0,0 @@ -import csv, os, shutil -import modules.scripts as scripts -from modules import processing, shared, sd_samplers, images -from modules.processing import Processed -from modules.shared import opts -import gradio -class Script(scripts.Script): - def title(self): - return "Create inspiration images" - - def show(self, is_img2img): - return True - - def ui(self, is_img2img): - file = gradio.Files(label="Artist or styles name list. '.txt' files with one name per line",) - with gradio.Row(): - prefix = gradio.Textbox("a painting in", label="Prompt words before artist or style name", file_count="multiple") - suffix= gradio.Textbox("style", label="Prompt words after artist or style name") - negative_prompt = gradio.Textbox("picture frame, portrait photo", label="Negative Prompt") - with gradio.Row(): - batch_size = gradio.Number(1, label="Batch size") - batch_count = gradio.Number(2, label="Batch count") - return [batch_size, batch_count, prefix, suffix, negative_prompt, file] - - def run(self, p, batch_size, batch_count, prefix, suffix, negative_prompt, files): - p.batch_size = int(batch_size) - p.n_iterint = int(batch_count) - p.negative_prompt = negative_prompt - p.do_not_save_samples = True - p.do_not_save_grid = True - for file in files: - tp = file.orig_name.split(".")[0] - print(tp) - path = os.path.join(opts.inspiration_dir, tp) - if not os.path.exists(path): - os.makedirs(path) - f = open(file.name, "r") - line = f.readline() - while len(line) > 0: - name = line.rstrip("\n").split(",")[0] - line = f.readline() - artist_path = os.path.join(path, name) - if not os.path.exists(artist_path): - os.mkdir(artist_path) - if len(os.listdir(artist_path)) >= opts.inspiration_max_samples: - continue - p.prompt = f"{prefix} {name} {suffix}" - print(p.prompt) - processed = processing.process_images(p) - for img in processed.images: - i = 0 - filename = os.path.join(artist_path, format(0, "03d") + ".jpg") - while os.path.exists(filename): - i += 1 - filename = os.path.join(artist_path, format(i, "03d") + ".jpg") - img.save(filename, quality=80) - return processed -- cgit v1.2.3 From 3be6b29d81408d2adb741bff5b11c80214aa621e Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 24 Oct 2022 15:14:34 +0900 Subject: indent=4 config.json --- modules/shared.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 6541e679..d6ddfe59 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -348,7 +348,7 @@ class Options: def save(self, filename): with open(filename, "w", encoding="utf8") as file: - json.dump(self.data, file) + json.dump(self.data, file, indent=4) def same_type(self, x, y): if x is None or y is None: -- cgit v1.2.3 From 8da1bd48bf9d0411cd9ba87b8d9220743cb5807e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 24 Oct 2022 14:03:58 +0300 Subject: add an option to skip adding number to filenames when saving. rework filename pattern function go through the pattern once and not calculate any of replacements until they are actually encountered in the pattern. --- modules/images.py | 250 ++++++++++++++++++++++++++++-------------------------- modules/shared.py | 8 +- 2 files changed, 135 insertions(+), 123 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/images.py b/modules/images.py index a9b1330d..848ede75 100644 --- a/modules/images.py +++ b/modules/images.py @@ -1,4 +1,7 @@ import datetime +import sys +import traceback + import pytz import io import math @@ -274,10 +277,15 @@ invalid_filename_chars = '<>:"/\\|?*\n' invalid_filename_prefix = ' ' invalid_filename_postfix = ' .' re_nonletters = re.compile(r'[\s' + string.punctuation + ']+') +re_pattern = re.compile(r"([^\[\]]+|\[([^]]+)]|[\[\]]*)") +re_pattern_arg = re.compile(r"(.*)<([^>]*)>$") max_filename_part_length = 128 def sanitize_filename_part(text, replace_spaces=True): + if text is None: + return None + if replace_spaces: text = text.replace(' ', '_') @@ -287,49 +295,103 @@ def sanitize_filename_part(text, replace_spaces=True): return text -def apply_filename_pattern(x, p, seed, prompt): - max_prompt_words = opts.directories_max_prompt_words - - if seed is not None: - x = re.sub(r'\[seed]', str(seed), x, flags=re.IGNORECASE) - - if p is not None: - x = re.sub(r'\[steps]', str(p.steps), x, flags=re.IGNORECASE) - x = re.sub(r'\[cfg]', str(p.cfg_scale), x, flags=re.IGNORECASE) - x = re.sub(r'\[width]', str(p.width), x, flags=re.IGNORECASE) - x = re.sub(r'\[height]', str(p.height), x, flags=re.IGNORECASE) - x = re.sub(r'\[styles]', sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]) or "None", replace_spaces=False), x, flags=re.IGNORECASE) - x = re.sub(r'\[sampler]', sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False), x, flags=re.IGNORECASE) - - x = re.sub(r'\[model_hash]', getattr(p, "sd_model_hash", shared.sd_model.sd_model_hash), x, flags=re.IGNORECASE) - current_time = datetime.datetime.now() - x = re.sub(r'\[date]', current_time.strftime('%Y-%m-%d'), x, flags=re.IGNORECASE) - x = replace_datetime(x, current_time) # replace [datetime], [datetime], [datetime